Commit 24103cbb (unverified)
Authored by Wilber on Feb 08, 2022; committed via GitHub on Feb 08, 2022

[PTEN] Update gpu_context. (#39359)

* gpu_context.. * update * update * update

Parent: 0fee0044
Showing 7 changed files with 280 additions and 237 deletions (+280 / -237).
Changed files:
  paddle/fluid/operators/conv_cudnn_helper.h    +0    -1
  paddle/fluid/operators/math/im2col.cu         +36   -20
  paddle/fluid/operators/math/vol2col.cu        +159  -157
  paddle/fluid/platform/device_context.cc       +15   -4
  paddle/fluid/platform/device_context.h        +2    -1
  paddle/pten/backends/gpu/gpu_context.cc       +18   -52
  paddle/pten/backends/gpu/gpu_context.h        +50   -2
paddle/fluid/operators/conv_cudnn_helper.h

@@ -288,7 +288,6 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
         ctx.template device_context<platform::CUDADeviceContext>();
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-    auto& temp = ctx.cuda_device_context();
     AlgorithmsCache<algo_t>& algo_cache =
         *(framework::ConvSearchCache::Instance().GetForward());
paddle/fluid/operators/math/im2col.cu

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/im2col.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 
 namespace paddle {
 namespace operators {
@@ -73,12 +74,12 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
  * col =
  *   [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
+template <class DeviceContext, class T>
 class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CUDADeviceContext, T> {
+                    DeviceContext, T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::Tensor& im,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
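Note on the pattern in this and the following hunks: the functor is no longer a specialization pinned to platform::CUDADeviceContext; the context type becomes a template parameter, so one body can serve both the old fluid context and the new pten::GPUContext. A minimal, self-contained sketch of that idiom (the types below are illustrative stand-ins, not the real Paddle classes):

#include <iostream>

// Two hypothetical context types standing in for platform::CUDADeviceContext
// and pten::GPUContext; they are not the real Paddle classes.
struct FluidLikeContext {
  const char* name() const { return "FluidLikeContext"; }
};
struct PtenLikeContext {
  const char* name() const { return "PtenLikeContext"; }
};

// One functor body parameterized on the context type, mirroring the new
// Im2ColFunctor<kCFO, DeviceContext, T> shape.
template <class DeviceContext, class T>
struct Im2ColLikeFunctor {
  void operator()(const DeviceContext& ctx, T value) const {
    std::cout << ctx.name() << " processed " << value << "\n";
  }
};

int main() {
  Im2ColLikeFunctor<FluidLikeContext, float> f_old;
  Im2ColLikeFunctor<PtenLikeContext, float> f_new;
  f_old(FluidLikeContext{}, 1.0f);  // existing fluid call sites keep working
  f_new(PtenLikeContext{}, 2.0f);   // pten-style contexts reuse the same body
  return 0;
}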
@@ -184,12 +185,11 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
  * col =
  *   [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
+template <class DeviceContext, class T>
 class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CUDADeviceContext, T> {
+                    DeviceContext, T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -257,10 +257,18 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, double>;
 
 template <class T>
 __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
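Because the functor bodies live in this .cu file rather than in a header, every (ColFormat, DeviceContext, T) combination used from other translation units has to be explicitly instantiated here; that is what the added pten::GPUContext lines provide. A compilable sketch of the same explicit-instantiation mechanism, with hypothetical names:

#include <iostream>

struct CtxA {};
struct CtxB {};

// Definitions stay in this translation unit, as the im2col kernels stay in
// the .cu file, so each context/dtype pairing must be instantiated below.
template <class DeviceContext, class T>
struct FunctorSketch {
  T Twice(const DeviceContext&, T v) const { return v + v; }
};

// Explicit instantiation definitions, mirroring the added lines in im2col.cu.
template struct FunctorSketch<CtxA, float>;
template struct FunctorSketch<CtxA, double>;
template struct FunctorSketch<CtxB, float>;
template struct FunctorSketch<CtxB, double>;

int main() {
  FunctorSketch<CtxB, double> f;
  std::cout << f.Twice(CtxB{}, 21.0) << "\n";  // prints 42
  return 0;
}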
@@ -299,12 +307,12 @@ __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
  * col =
  *   [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
+template <class DeviceContext, class T>
 class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CUDADeviceContext, T> {
+                    DeviceContext, T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::Tensor& im,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -390,12 +398,11 @@ __global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
  * col =
  *   [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
+template <class DeviceContext, class T>
 class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CUDADeviceContext, T> {
+                    DeviceContext, T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -464,10 +471,19 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, double>;
 
 }  // namespace math
 }  // namespace operators
paddle/fluid/operators/math/vol2col.cu

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/vol2col.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 
 namespace paddle {
 namespace operators {
@@ -82,93 +83,91 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
  * [input_channels, filter_depth, filter_height, filter_width,
  *  output_depth, output_height, output_width]
  */
-template <class T>
-class Vol2ColFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& vol,
-                  const std::vector<int>& dilations,
-                  const std::vector<int>& strides,
+// template <class DeviceContext, class T>
+// class Vol2ColFunctor {
+//  public:
+template <class DeviceContext, class T>
+void Vol2ColFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& vol,
+    const std::vector<int>& dilations, const std::vector<int>& strides,
                   const std::vector<int>& paddings, framework::Tensor* col,
                   const DataLayout data_layout) const {
   PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                     platform::errors::InvalidArgument(
                         "The dimension of vol should be 4, but received %d.",
                         vol.dims().size()));
   PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                     platform::errors::InvalidArgument(
                         "The dimension of col should be 7, but received %d.",
                         col->dims().size()));
   int input_channels =
       (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
   int input_depth =
       (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
   int input_height =
       (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
   int input_width =
       (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
   int filter_depth = col->dims()[1];
   int filter_height = col->dims()[2];
   int filter_width = col->dims()[3];
   int output_depth = col->dims()[4];
   int output_height = col->dims()[5];
   int output_width = col->dims()[6];
   bool paddings_size_is_6 = (paddings.size() == 6);
   int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
   int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
   int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
   int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
   int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
   int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
   auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                           ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1;
   PADDLE_ENFORCE_EQ(
       input_depth_tmp, output_depth,
       platform::errors::InvalidArgument(
           "input_depth(%d) and output_depth(%d) are mismatching.",
           input_depth_tmp, output_depth));
   auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                            ((dilations[1] * (filter_height - 1) + 1))) /
                               strides[1] +
                           1;
   PADDLE_ENFORCE_EQ(
       input_height_tmp, output_height,
       platform::errors::InvalidArgument(
           "input_height(%d) and output_height(%d) are mismatching.",
           input_height_tmp, output_height));
   auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                           ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1;
-  PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
-                    platform::errors::InvalidArgument(
-                        "input_width(%d) and output_width(%d) are mismatching.",
-                        input_width_tmp, output_width));
+  PADDLE_ENFORCE_EQ(
+      input_width_tmp, output_width,
+      platform::errors::InvalidArgument(
+          "input_width(%d) and output_width(%d) are mismatching.",
+          input_width_tmp, output_width));
   int num_outputs =
       input_channels * output_depth * output_height * output_width;
 
   int max_threads = 1024;
 #ifdef WITH_NV_JETSON
   platform::ChangeThreadNum(context, &max_threads);
 #endif
 
   const int threads = max_threads;
   const int blocks = (num_outputs + max_threads - 1) / max_threads;
 
   vol2col<T><<<blocks, threads, 0, context.stream()>>>(
       num_outputs, vol.data<T>(), input_depth, input_height, input_width,
       dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
       filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
       pad_w_left, output_depth, output_height, output_width, col->data<T>(),
       data_layout);
 }
-};
+// };
 
 template <class T>
 __global__ void col2vol(int num_kernels, const T* data_col, int depth,
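vol2col.cu goes one step further than im2col.cu: the old class specialization is commented out and operator() is now defined out of line for a generic DeviceContext. That presumes the primary template is declared elsewhere (vol2col.h) and relies on the explicit instantiations at the end of the file to emit code. A standalone sketch of that structure (illustrative names only, not the real Paddle declarations):

#include <iostream>

// Header-like part: the primary template only declares operator(), the way
// vol2col.h would.
template <class DeviceContext, class T>
class VolFunctorSketch {
 public:
  void operator()(const DeviceContext& ctx, T v) const;
};

struct GpuCtxSketch {
  const char* name() const { return "gpu-ctx"; }
};

// Source-file part: the member is defined out of line for any DeviceContext,
// matching the new "void Vol2ColFunctor<DeviceContext, T>::operator()(...)".
template <class DeviceContext, class T>
void VolFunctorSketch<DeviceContext, T>::operator()(const DeviceContext& ctx,
                                                    T v) const {
  std::cout << ctx.name() << ": " << v << "\n";
}

// Without this explicit instantiation, callers in other files would hit
// undefined-symbol errors at link time.
template class VolFunctorSketch<GpuCtxSketch, float>;

int main() {
  VolFunctorSketch<GpuCtxSketch, float> f;
  f(GpuCtxSketch{}, 3.5f);
  return 0;
}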
@@ -249,98 +248,101 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
  * [input_channels, filter_depth, filter_height, filter_width,
  *  output_depth, output_height, output_width]
  */
-template <class T>
-class Col2VolFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
-                  const std::vector<int>& dilations,
-                  const std::vector<int>& strides,
+// template <class DeviceContext, class T>
+// class Col2VolFunctor<DeviceContext, T> {
+//  public:
+template <class DeviceContext, class T>
+void Col2VolFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& col,
+    const std::vector<int>& dilations, const std::vector<int>& strides,
                   const std::vector<int>& paddings, framework::Tensor* vol,
                   const DataLayout data_layout) const {
   PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
                     platform::errors::InvalidArgument(
                         "The dimension of vol should be 4, but received %d.",
                         vol->dims().size()));
   PADDLE_ENFORCE_EQ(col.dims().size(), 7,
                     platform::errors::InvalidArgument(
                         "The dimension of col should be 7, but received %d.",
                         col.dims().size()));
   int input_channels =
       (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
   int input_depth =
       (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
   int input_height =
       (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
   int input_width =
       (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
   int filter_depth = col.dims()[1];
   int filter_height = col.dims()[2];
   int filter_width = col.dims()[3];
   int output_depth = col.dims()[4];
   int output_height = col.dims()[5];
   int output_width = col.dims()[6];
   bool paddings_size_is_6 = (paddings.size() == 6);
   int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
   int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
   int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
   int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
   int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
   int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
   auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                           ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1;
   PADDLE_ENFORCE_EQ(
       input_depth_tmp, output_depth,
       platform::errors::InvalidArgument(
           "input_depth(%d) and output_depth(%d) are mismatching.",
           input_depth_tmp, output_depth));
   auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                            ((dilations[1] * (filter_height - 1) + 1))) /
                               strides[1] +
                           1;
   PADDLE_ENFORCE_EQ(
       input_height_tmp, output_height,
       platform::errors::InvalidArgument(
           "input_height(%d) and output_height(%d) are mismatching.",
           input_height_tmp, output_height));
   auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                           ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1;
-  PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
-                    platform::errors::InvalidArgument(
-                        "input_width(%d) and output_width(%d) are mismatching.",
-                        input_width_tmp, output_width));
+  PADDLE_ENFORCE_EQ(
+      input_width_tmp, output_width,
+      platform::errors::InvalidArgument(
+          "input_width(%d) and output_width(%d) are mismatching.",
+          input_width_tmp, output_width));
   int num_kernels = input_channels * input_depth * input_height * input_width;
 
   int max_threads = 1024;
 #ifdef WITH_NV_JETSON
   platform::ChangeThreadNum(context, &max_threads);
 #endif
 
   const int threads = max_threads;
   const int blocks = (num_kernels + max_threads - 1) / max_threads;
 
   col2vol<T><<<blocks, threads, 0, context.stream()>>>(
       num_kernels, col.data<T>(), input_depth, input_height, input_width,
       dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
       filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
       pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
       data_layout);
 }
-};
+// };
 
 template class Vol2ColFunctor<platform::CUDADeviceContext, float>;
 template class Vol2ColFunctor<platform::CUDADeviceContext, double>;
+template class Vol2ColFunctor<pten::GPUContext, float>;
+template class Vol2ColFunctor<pten::GPUContext, double>;
 template class Col2VolFunctor<platform::CUDADeviceContext, float>;
 template class Col2VolFunctor<platform::CUDADeviceContext, double>;
+template class Col2VolFunctor<pten::GPUContext, float>;
+template class Col2VolFunctor<pten::GPUContext, double>;
 
 }  // namespace math
 }  // namespace operators
paddle/fluid/platform/device_context.cc

@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/stream/cuda_stream.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/allocator.h"
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"

@@ -485,8 +486,11 @@ CUDAContext::~CUDAContext() {
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
     : pten::GPUContext(place) {
   pten::GPUContext::PartialInitWithoutAllocator();
   cuda_stream_.reset(
-      new stream::CUDAStream(pten::GPUContext::stream(), place));
+      new stream::CUDAStream(pten::GPUContext::stream(), this->GetPlace()));
+  workspace_.reset(new pten::DnnWorkspaceHandle(
+      memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(place, pten::GPUContext::stream()).get()));
 }
 
 CUDADeviceContext::~CUDADeviceContext() = default;

@@ -571,8 +575,15 @@ void CUDADeviceContext::WaitStreamCallback() const {
   pten::GPUContext::WaitStreamCallback();
 }
 
-CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
-  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
+pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
+  if (thread_ctx_.count(this)) {
+    // return workspace_.get();
+    return pten::DnnWorkspaceHandle(
+        memory::allocation::AllocatorFacade::Instance()
+            .GetAllocator(GetPlace(), pten::GPUContext::stream())
+            .get());
+  }
+  return pten::GPUContext::cudnn_workspace_handle();
 }
 
 gpuStream_t CUDADeviceContext::stream() const {
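With this change cudnn_workspace_handle() hands back a pten::DnnWorkspaceHandle by value instead of the old lock-holding CudnnWorkspaceHandle, but the caller-facing contract stays the same: request N bytes and receive the raw workspace pointer inside a callback. A rough host-memory stand-in showing that calling convention (hypothetical names; the real handle allocates through the device allocator):

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Stripped-down stand-in for the workspace handle returned by
// cudnn_workspace_handle(): callers ask for N bytes and get a raw pointer
// inside a callback, mirroring the RunFunc contract kept by the diff.
class WorkspaceHandleSketch {
 public:
  void RunFunc(const std::function<void(void*)>& fn, size_t required_bytes) {
    if (required_bytes > buffer_.size()) {
      buffer_.resize(required_bytes);  // grow-only, like ReallocWorkspace
    }
    fn(buffer_.empty() ? nullptr : buffer_.data());
  }

 private:
  std::vector<unsigned char> buffer_;
};

int main() {
  // In Paddle this handle would come from dev_ctx.cudnn_workspace_handle().
  WorkspaceHandleSketch handle;
  handle.RunFunc(
      [](void* workspace_ptr) {
        std::cout << "got workspace at " << workspace_ptr << "\n";
      },
      1024);
  return 0;
}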
paddle/fluid/platform/device_context.h

@@ -566,7 +566,7 @@ class CUDADeviceContext : public pten::GPUContext {
    * workspace. Once the handle is destructed, the lock would be released.
    * CudnnWorkspaceHandle is an RAII object to implement thread-safe
    * sequential cudnn function calls. */
-  CudnnWorkspaceHandle cudnn_workspace_handle() const;
+  pten::DnnWorkspaceHandle cudnn_workspace_handle() const;
 
   /*! \brief Return cuda stream in the device context. */
   gpuStream_t stream() const;

@@ -607,6 +607,7 @@ class CUDADeviceContext : public pten::GPUContext {
   // NOTE: Just for compatibility with the past, please delete if there is an
   // elegant way.
   std::unique_ptr<stream::CUDAStream> cuda_stream_;
+  std::unique_ptr<pten::DnnWorkspaceHandle> workspace_{nullptr};
 
   DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
paddle/pten/backends/gpu/gpu_context.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include <algorithm>
 #include <array>
 #include <functional>
 #include <future>

@@ -153,55 +154,14 @@ static void StreamCallbackFunc(gpuStream_t stream,
 
 }  // namespace internal
 
-class DnnWorkspaceHandle {
- public:
-  explicit inline DnnWorkspaceHandle(Allocator* allocator)
-      : allocator_(allocator) {}
-
-  inline void RunFunc(const std::function<void(void*)>& cudnn_func,
-                      size_t required_workspace_bytes) {
-    if (required_workspace_bytes > WorkspaceSize()) {
-      ReallocWorkspace(required_workspace_bytes);
-    }
-    VLOG(2) << "Cudnn workspace size at RunFunc: "
-            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
-    {
-      std::lock_guard<std::mutex> guard(mtx_);
-      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
-    }
-  }
-
-  /*! \brief Thread which call RunFuncSync() would release gpu memory after
-   *  running the function. Currently this function is only used when cudnn
-   *  exhaustive searching and callers have to guarantee that the input function
-   *  is host blocking */
-  inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
-                          size_t required_workspace_bytes) {
-    RunFunc(cudnn_func, required_workspace_bytes);
-    ResetWorkspace();
-  }
-
-  inline size_t WorkspaceSize() {
-    if (allocation_ == nullptr) {
-      return 0;
-    }
-    return allocation_->size();
-  }
-
-  void ResetWorkspace() { allocation_ = nullptr; }
-
-  void ReallocWorkspace(size_t required_workspace_bytes) {
-    if (required_workspace_bytes <= WorkspaceSize()) return;
-    // reset allocation first before re-allocate to save memory
-    allocation_.reset();
-    allocation_ = allocator_->Allocate(required_workspace_bytes);
-  }
-
- private:
-  Allocator::AllocationPtr allocation_{nullptr};
-  Allocator* allocator_{nullptr};
-  std::mutex mtx_;
-};
+void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; }
+
+void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
+  if (required_workspace_bytes <= WorkspaceSize()) return;
+  // reset allocation first before re-allocate to save memory
+  allocation_.reset();
+  allocation_ = allocator_->Allocate(required_workspace_bytes);
+}
 
 struct GPUContext::Impl {
   void Init() {

@@ -341,9 +301,15 @@ struct GPUContext::Impl {
     }
   }
 
-  DnnWorkspaceHandle* GetDnnWorkspace() {
-    PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
-    return workspace_;
+  // TODO(wilber): The return type is a pointer, to be modified later.
+  // DnnWorkspaceHandle* GetDnnWorkspace() {
+  //   PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
+  //   return workspace_;
+  // }
+  DnnWorkspaceHandle GetDnnWorkspace() {
+    PD_CHECK(allocator_ != nullptr,
+             "the device allocator for gpu context is nullptr.");
+    return DnnWorkspaceHandle(allocator_);
   }
 
   void InitStream() {

@@ -797,7 +763,7 @@ Eigen::GpuDevice* GPUContext::eigen_device() const {
   return impl_->eigen_device();
 }
 
-DnnWorkspaceHandle* GPUContext::cudnn_workspace_handle() {
+DnnWorkspaceHandle GPUContext::cudnn_workspace_handle() const {
   return impl_->GetDnnWorkspace();
 }
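The Impl now only keeps the device allocator and builds a fresh DnnWorkspaceHandle per call instead of returning a cached pointer. A small sketch of that contract (hypothetical types; PD_CHECK is modeled as a plain exception):

#include <stdexcept>

struct AllocatorSketch {};

struct WorkspaceSketch {
  explicit WorkspaceSketch(AllocatorSketch* a) : allocator(a) {}
  AllocatorSketch* allocator;
};

struct GpuContextImplSketch {
  AllocatorSketch* allocator_ = nullptr;

  WorkspaceSketch GetDnnWorkspace() const {
    if (allocator_ == nullptr) {
      throw std::runtime_error(
          "the device allocator for gpu context is nullptr.");
    }
    return WorkspaceSketch(allocator_);  // returned by value, not a pointer
  }
};

int main() {
  AllocatorSketch allocator;
  GpuContextImplSketch impl;
  impl.allocator_ = &allocator;
  WorkspaceSketch workspace = impl.GetDnnWorkspace();
  (void)workspace;
  return 0;
}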
paddle/pten/backends/gpu/gpu_context.h

@@ -16,6 +16,7 @@ limitations under the License. */
 #include <array>
 #include <functional>
+#include <mutex>
 #include "paddle/pten/backends/gpu/forwards.h"
 #include "paddle/pten/backends/gpu/gpu_decls.h"
 #include "paddle/pten/backends/gpu/gpu_helper.h"

@@ -24,7 +25,53 @@ limitations under the License. */
 namespace pten {
 
-class DnnWorkspaceHandle;
+class DnnWorkspaceHandle {
+ public:
+  explicit inline DnnWorkspaceHandle(Allocator* allocator)
+      : allocator_(allocator) {
+    mtx_.reset(new std::mutex());
+  }
+
+  inline void RunFunc(const std::function<void(void*)>& cudnn_func,
+                      size_t required_workspace_bytes) {
+    if (required_workspace_bytes > WorkspaceSize()) {
+      ReallocWorkspace(required_workspace_bytes);
+    }
+    {
+      std::lock_guard<std::mutex> guard(*mtx_);
+      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
+    }
+  }
+
+  /*! \brief Thread which call RunFuncSync() would release gpu memory after
+   *  running the function. Currently this function is only used when cudnn
+   *  exhaustive searching and callers have to guarantee that the input function
+   *  is host blocking */
+  inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
+                          size_t required_workspace_bytes) {
+    RunFunc(cudnn_func, required_workspace_bytes);
+    ResetWorkspace();
+  }
+
+  inline size_t WorkspaceSize() {
+    if (allocation_ == nullptr) {
+      return 0;
+    }
+    return allocation_->size();
+  }
+
+  void ResetWorkspace();
+
+  void ReallocWorkspace(size_t required_workspace_bytes);
+
+  DnnWorkspaceHandle(DnnWorkspaceHandle&&) = default;
+  DnnWorkspaceHandle& operator=(DnnWorkspaceHandle&&) = delete;
+
+ private:
+  Allocator::AllocationPtr allocation_{nullptr};
+  Allocator* allocator_{nullptr};
+  std::unique_ptr<std::mutex> mtx_;
+};
 
 class GPUContext : public DeviceContext {
  public:

@@ -85,7 +132,8 @@ class GPUContext : public DeviceContext {
    * would be acquired to prevent other threads from accessing the
    * workspace. Once the handle is destructed, the lock would be released.
    */
-  DnnWorkspaceHandle* cudnn_workspace_handle();
+  // TODO(wilber): The return type is a pointer, to be modified later.
+  DnnWorkspaceHandle cudnn_workspace_handle() const;
 
  public:
   /*! \brief Call cublas function safely. */
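A detail worth noting in the new header: the handle stores its mutex behind std::unique_ptr and defaults the move constructor, which is what lets cudnn_workspace_handle() return the handle by value even though std::mutex itself cannot be moved. A minimal illustration (not the real class):

#include <memory>
#include <mutex>

// std::mutex is neither copyable nor movable, so a class holding one by value
// could not be returned by value; holding it behind a unique_ptr keeps the
// defaulted move constructor viable, as in the new DnnWorkspaceHandle.
struct MovableHandleSketch {
  MovableHandleSketch() : mtx(new std::mutex()) {}
  MovableHandleSketch(MovableHandleSketch&&) = default;            // movable
  MovableHandleSketch& operator=(MovableHandleSketch&&) = delete;  // as in the diff
  std::unique_ptr<std::mutex> mtx;
};

// Returned by value, the way GPUContext::cudnn_workspace_handle() now returns
// its DnnWorkspaceHandle.
MovableHandleSketch MakeHandle() { return MovableHandleSketch(); }

int main() {
  MovableHandleSketch handle = MakeHandle();
  std::lock_guard<std::mutex> guard(*handle.mtx);
  return 0;
}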