Commit bd5e97d3 (unverified)
Authored Jun 21, 2022 by Zhang Ting; committed via GitHub on Jun 21, 2022
Parent: 827d9992

slice large tensor for cudnn_softmax (#43681)

Showing 1 changed file with 182 additions and 124 deletions
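Why slicing is needed (background stated as an assumption, not taken from the commit page): cuDNN tensor descriptors store their extents as 32-bit ints, so a softmax input flattened to [N, dim] cannot be described in a single cuDNN call once N * dim exceeds INT32_MAX. The commit therefore caps each call at std::numeric_limits<int32_t>::max() / dim rows. A minimal sketch of that bound, with a hypothetical helper name:

#include <cstdint>
#include <limits>

// Largest number of rows per cuDNN call such that rows * dim still fits
// into a 32-bit descriptor extent.
inline int64_t MaxRowsPerSlice(int dim) {
  return std::numeric_limits<int32_t>::max() / dim;
}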
paddle/phi/kernels/gpudnn/softmax_gpudnn.h (+182 -124)
@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
 template <typename T>
 void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
-                               const DenseTensor& x,
+                               const T* x_data,
                                const int axis,
+                               const int rank,
                                const bool log_mode,
-                               DenseTensor* out) {
-  auto* out_data = out->data<T>();
-
-  const int rank = x.dims().size();
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
-
+                               const std::vector<int>& tensor_dims,
+                               T* out_data) {
   auto handle = dev_ctx.cudnn_handle();
   GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
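The hunk context above references GetSoftmaxTensorDims, defined earlier in this header. Judging from how the drivers below index its result (tensor_dims[0] = N, tensor_dims[1] = dim, tensor_dims[2] = D), it flattens the input shape around axis into {N, dim, D, 1}. An illustrative re-implementation under that assumption, not the file's actual code:

#include <cstdint>
#include <vector>

// Assumed semantics: N = product of dims before `axis`, dim = dims[axis],
// D = product of dims after `axis`; the trailing 1 pads to 4-D for cuDNN.
std::vector<int> SoftmaxTensorDimsSketch(const std::vector<int64_t>& dims,
                                         int axis) {
  int64_t n = 1, d = 1;
  for (int i = 0; i < axis; ++i) n *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) d *= dims[i];
  return {static_cast<int>(n), static_cast<int>(dims[axis]),
          static_cast<int>(d), 1};
}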
@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
         handle,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        x.data<T>(),
+        x_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         out_data,
@@ -812,7 +809,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
         mode,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        x.data<T>(),
+        x_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         out_data));
@@ -820,17 +817,39 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T>
-void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
-                                const DenseTensor& out,
-                                const DenseTensor& dout,
-                                const int axis,
-                                const bool log_mode,
-                                DenseTensor* dx) {
-  auto* dx_data = dx->data<T>();
-
-  int rank = out.dims().size();
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
-
+void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
+                                     const DenseTensor& x,
+                                     const int axis,
+                                     const bool log_mode,
+                                     DenseTensor* out) {
+  auto* out_data = out->data<T>();
+  auto* x_data = x.data<T>();
+  const int rank = x.dims().size();
+  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
+  int64_t remaining = tensor_dims[0];
+  int dim = tensor_dims[1];
+  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
+  int offset = batch_size * dim;
+  while (remaining > 0) {
+    tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
+    SoftmaxForwardCudnnKernel<T>(
+        dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
+    x_data += offset;
+    out_data += offset;
+    remaining -= batch_size;
+  }
+}
+
+template <typename T>
+void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
+                                const T* out_data,
+                                const T* dout_data,
+                                const int axis,
+                                const int rank,
+                                const bool log_mode,
+                                const std::vector<int>& tensor_dims,
+                                T* dx_data) {
   auto handle = dev_ctx.cudnn_handle();
   GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
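The new LaunchSoftmaxForwardCudnnKernel above is the core of the change: it walks the flattened [N, dim] problem in int32-safe chunks, calling the per-slice cuDNN kernel with advanced pointers. The same pattern as a self-contained sketch over a generic per-slice callback (hypothetical names, not the commit's code):

#include <algorithm>
#include <cstdint>
#include <limits>

template <typename T, typename RunSlice>
void ForEachInt32Slice(const T* in, T* out, int64_t rows, int dim,
                       RunSlice run_slice) {
  // Rows per chunk are chosen so that rows_per_chunk * dim <= INT32_MAX.
  const int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
  const int64_t offset = batch_size * dim;  // elements consumed per full chunk
  for (int64_t remaining = rows; remaining > 0; remaining -= batch_size) {
    const int cur_rows =
        static_cast<int>(std::min<int64_t>(remaining, batch_size));
    run_slice(in, out, cur_rows, dim);  // e.g. one descriptor-sized cuDNN call
    in += offset;
    out += offset;
  }
}

Note that the commit stores the stride as int offset; that narrowing is safe because batch_size * dim is at most INT32_MAX by construction.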
@@ -846,9 +865,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
         handle,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        out.data<T>(),
+        out_data,
         desc,
-        dout.data<T>(),
+        dout_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         dx_data,
@@ -865,18 +884,52 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
         mode,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        out.data<T>(),
+        out_data,
         desc,
-        dout.data<T>(),
+        dout_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         dx_data));
 #endif
 }
 
+template <typename T>
+void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
+                                      const DenseTensor& out,
+                                      const DenseTensor& dout,
+                                      const int axis,
+                                      const bool log_mode,
+                                      DenseTensor* dx) {
+  auto* dx_data = dx->data<T>();
+  auto* out_data = out.data<T>();
+  auto* dout_data = dout.data<T>();
+  int rank = out.dims().size();
+  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
+  int64_t remaining = tensor_dims[0];
+  int dim = tensor_dims[1];
+  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
+  int offset = batch_size * dim;
+  while (remaining > 0) {
+    tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
+    SoftmaxBackwardCudnnKernel<T>(
+        dev_ctx, out_data, dout_data, axis, rank, log_mode, tensor_dims,
+        dx_data);
+    out_data += offset;
+    dout_data += offset;
+    dx_data += offset;
+    remaining -= batch_size;
+  }
+}
+
 #if CUDNN_VERSION < 8100
 template <>
-inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
+inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
     const DenseTensor& x,
     const int axis,
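The backward wrapper in this hunk mirrors the forward one, advancing out, dout, and dx in lockstep. Slicing along the batch dimension preserves results because softmax over the reduced axis treats rows independently; a CPU-only, self-contained demonstration of that invariant (illustrative code, not from the commit):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Row-wise softmax over a [rows, dim] buffer, numerically stabilized.
void SoftmaxRows(const float* x, float* y, int rows, int dim) {
  for (int r = 0; r < rows; ++r) {
    const float* row = x + r * dim;
    float* out = y + r * dim;
    float mx = row[0];
    for (int j = 1; j < dim; ++j) mx = std::max(mx, row[j]);
    float sum = 0.f;
    for (int j = 0; j < dim; ++j) sum += std::exp(row[j] - mx);
    for (int j = 0; j < dim; ++j) out[j] = std::exp(row[j] - mx) / sum;
  }
}

int main() {
  const int rows = 7, dim = 5, chunk = 3;  // chunk plays the batch_size role
  std::vector<float> x(rows * dim), once(rows * dim), sliced(rows * dim);
  for (int i = 0; i < rows * dim; ++i) x[i] = 0.1f * (i % 11) - 0.5f;
  // All rows in one call.
  SoftmaxRows(x.data(), once.data(), rows, dim);
  // Same rows, processed chunk by chunk with advancing pointers.
  for (int r = 0; r < rows; r += chunk) {
    int cur = std::min(chunk, rows - r);
    SoftmaxRows(x.data() + r * dim, sliced.data() + r * dim, cur, dim);
  }
  float max_diff = 0.f;
  for (int i = 0; i < rows * dim; ++i)
    max_diff = std::max(max_diff, std::fabs(once[i] - sliced[i]));
  std::printf("max |once - sliced| = %g\n", max_diff);  // expect exactly 0
  return 0;
}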
@@ -887,7 +940,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
                                    "8100."));
 }
 
 template <>
-inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
+inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
     const DenseTensor& out,
     const DenseTensor& dout,
@@ -933,7 +986,8 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];
 
-  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
+  if (D == 1) {
+    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
       int dim_log2 = static_cast<int>(Log2Ceil(dim));
       int dim_ceil = 1 << dim_log2;
       int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
@@ -982,11 +1036,12 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
           dim,
           dim_log2);
       }
-  } else if (D > 1) {
-    LaunchNormalSoftmaxForward<T, LogMode>(
-        dev_ctx, out_data, x.data<T>(), N, dim, D);
-  } else {
-    SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
+    } else {
+      LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
+    }
+  } else {
+    LaunchNormalSoftmaxForward<T, LogMode>(
+        dev_ctx, out_data, x.data<T>(), N, dim, D);
   }
 }
@@ -1005,7 +1060,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];
 
-  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
+  if (D == 1) {
+    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
       int dim_log2 = Log2Ceil(dim);
       int dim_ceil = 1 << dim_log2;
       int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
@@ -1055,11 +1111,13 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
           dim,
           dim_log2);
       }
-  } else if (D > 1) {
-    LaunchNormalSoftmaxBackward<T, LogMode>(
-        dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
-  } else {
-    SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
+    } else {
+      LaunchSoftmaxBackwardCudnnKernel<T>(
+          dev_ctx, out, dout, axis, LogMode, dx);
+    }
+  } else {
+    LaunchNormalSoftmaxBackward<T, LogMode>(
+        dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
   }
 }
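After this restructuring, both drivers share the same dispatch shape. In outline, with stub comments standing in for the real kernels (a sketch, not code from the file):

// Dispatch outline common to SoftmaxForwardCUDAKernelDriver and
// SoftmaxBackwardCUDAKernelDriver after this commit.
void DispatchSketch(bool inner_axis_only /* D == 1 */, bool use_cudnn) {
  if (inner_axis_only) {
    if (!use_cudnn) {
      // WarpSoftmax path: dim is small enough for a warp-per-row kernel.
    } else {
      // LaunchSoftmax{Forward,Backward}CudnnKernel: cuDNN, sliced into
      // int32-sized batches by the new wrappers.
    }
  } else {
    // LaunchNormalSoftmax{Forward,Backward}: general kernel for D > 1.
  }
}

Previously the cuDNN call sat in a trailing else branch (reached when D == 1 and UseCudnnSoftmax held); the behavior is unchanged, but the new nesting gives the sliced cuDNN wrappers a single, explicit call site in each driver.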