Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
71748805
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
71748805
编写于
10月 13, 2022
作者:
C
carryyu
提交者:
GitHub
10月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix softmax memory align (#46902)
上级
cf9ca61d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
93 addition
and
18 deletion
+93
-18
paddle/phi/kernels/gpudnn/softmax_gpudnn.h
paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+93
-18
未找到文件。
paddle/phi/kernels/gpudnn/softmax_gpudnn.h
浏览文件 @
71748805
...
...
@@ -346,28 +346,41 @@ template <template <typename, typename> class Reduction,
typename
AccT
,
int
VecSize
>
__device__
__forceinline__
AccT
ThreadVecReduce
(
const
T
*
data
,
ThreadVecReduce
(
T
*
data
,
int
dim_size
,
const
int
shift
,
const
Reduction
<
T
,
AccT
>&
functor
,
AccT
default_value
)
{
using
VecT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
AccT
thread_val
=
default_value
;
// for memory align, handle the unaligned data in first block.
int
offset
=
threadIdx
.
x
;
if
(
shift
>
0
)
{
data
-=
shift
;
dim_size
+=
shift
;
if
(
offset
>=
shift
)
{
thread_val
=
functor
(
thread_val
,
data
[
offset
]);
}
dim_size
-=
blockDim
.
x
;
data
+=
blockDim
.
x
;
}
const
int
last
=
dim_size
%
(
VecSize
*
blockDim
.
x
);
T
v
[
VecSize
];
VecT
*
value
=
reinterpret_cast
<
VecT
*>
(
&
v
);
for
(
int
offset
=
threadIdx
.
x
;
offset
*
VecSize
<
dim_size
-
last
;
offset
+=
blockDim
.
x
)
{
*
value
=
reinterpret_cast
<
VecT
*>
(
const_cast
<
T
*>
(
data
))[
offset
];
for
(;
offset
*
VecSize
<
dim_size
-
last
;
offset
+=
blockDim
.
x
)
{
*
value
=
reinterpret_cast
<
VecT
*>
(
data
)[
offset
];
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
thread_val
=
functor
(
thread_val
,
v
[
i
]);
}
}
for
(
int
offset
=
dim_size
-
last
+
threadIdx
.
x
;
offset
<
dim_size
;
offset
+=
blockDim
.
x
)
{
offset
=
dim_size
-
last
+
threadIdx
.
x
;
for
(;
offset
<
dim_size
;
offset
+=
blockDim
.
x
)
{
thread_val
=
functor
(
thread_val
,
data
[
offset
]);
}
return
thread_val
;
...
...
@@ -377,12 +390,27 @@ template <template <typename, typename> class Reduction,
typename
T
,
typename
AccT
,
int
VecSize
>
__device__
__forceinline__
void
ThreadVecWrite
(
T
*
out
,
const
T
*
input
,
int
dim_size
,
Reduction
<
AccT
,
T
>
functor
)
{
__device__
__forceinline__
void
ThreadVecWriteVec
(
T
*
out
,
T
*
input
,
int
dim_size
,
const
int
shift
,
Reduction
<
AccT
,
T
>
functor
)
{
using
VecT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
// for memory align, handle the unaligned data in first block.
int
offset
=
threadIdx
.
x
;
if
(
shift
>
0
)
{
input
-=
shift
;
out
-=
shift
;
dim_size
+=
shift
;
if
(
offset
>=
shift
)
{
out
[
offset
]
=
functor
(
static_cast
<
AccT
>
(
input
[
offset
]));
}
dim_size
-=
blockDim
.
x
;
input
+=
blockDim
.
x
;
out
+=
blockDim
.
x
;
}
const
int
last
=
dim_size
%
(
VecSize
*
blockDim
.
x
);
T
in_v
[
VecSize
];
...
...
@@ -391,9 +419,8 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
T
out_v
[
VecSize
];
VecT
*
out_value
=
reinterpret_cast
<
VecT
*>
(
&
out_v
);
for
(
int
offset
=
threadIdx
.
x
;
offset
*
VecSize
<
dim_size
-
last
;
offset
+=
blockDim
.
x
)
{
*
in_value
=
reinterpret_cast
<
VecT
*>
(
const_cast
<
T
*>
(
input
))[
offset
];
for
(;
offset
*
VecSize
<
dim_size
-
last
;
offset
+=
blockDim
.
x
)
{
*
in_value
=
reinterpret_cast
<
VecT
*>
(
input
)[
offset
];
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
out_v
[
i
]
=
functor
(
static_cast
<
AccT
>
(
in_v
[
i
]));
...
...
@@ -401,6 +428,33 @@ __device__ __forceinline__ void ThreadVecWrite(T* out,
reinterpret_cast
<
VecT
*>
(
out
)[
offset
]
=
*
out_value
;
}
offset
=
dim_size
-
last
+
threadIdx
.
x
;
// the tail
for
(;
offset
<
dim_size
;
offset
+=
blockDim
.
x
)
{
out
[
offset
]
=
functor
(
static_cast
<
AccT
>
(
input
[
offset
]));
}
}
template
<
template
<
typename
,
typename
>
class
Reduction
,
typename
T
,
typename
AccT
,
int
VecSize
>
__device__
__forceinline__
void
ThreadVecWrite
(
T
*
out
,
T
*
input
,
int
dim_size
,
Reduction
<
AccT
,
T
>
functor
)
{
const
int
last
=
dim_size
%
(
VecSize
*
blockDim
.
x
);
for
(
int
offset
=
threadIdx
.
x
;
offset
<
dim_size
-
last
;
offset
+=
blockDim
.
x
*
VecSize
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
out
[
offset
+
i
*
blockDim
.
x
]
=
functor
(
static_cast
<
AccT
>
(
input
[
offset
+
i
*
blockDim
.
x
]));
}
}
// the tail
for
(
int
offset
=
dim_size
-
last
+
threadIdx
.
x
;
offset
<
dim_size
;
offset
+=
blockDim
.
x
)
{
out
[
offset
]
=
functor
(
static_cast
<
AccT
>
(
input
[
offset
]));
...
...
@@ -417,13 +471,19 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
using
VecT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
int
bid
=
blockIdx
.
x
;
const
T
*
batch_input
=
src
+
bid
*
dim_size
;
T
*
batch_input
=
const_cast
<
T
*>
(
src
)
+
bid
*
dim_size
;
T
*
batch_output
=
softmax
+
bid
*
dim_size
;
const
int
input_align_shift
=
((
uint64_t
)
batch_input
)
%
MATRIX_SOFTMAX_ALIGN_BYTES
/
sizeof
(
T
);
const
int
output_align_shift
=
((
uint64_t
)
batch_output
)
%
MATRIX_SOFTMAX_ALIGN_BYTES
/
sizeof
(
T
);
// get max value
AccT
thread_max
=
ThreadVecReduce
<
MaxFunctor
,
T
,
AccT
,
VecSize
>
(
batch_input
,
dim_size
,
input_align_shift
,
MaxFunctor
<
T
,
AccT
>
(),
std
::
numeric_limits
<
AccT
>::
min
());
BlockReduceMax
<
AccT
>
(
&
thread_max
);
...
...
@@ -432,6 +492,7 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
AccT
thread_exp
=
ThreadVecReduce
<
SumExpFunctor
,
T
,
AccT
,
VecSize
>
(
batch_input
,
dim_size
,
input_align_shift
,
SumExpFunctor
<
T
,
AccT
>
(
thread_max
),
static_cast
<
AccT
>
(
0.
));
BlockReduceSum
<
AccT
>
(
&
thread_exp
);
...
...
@@ -440,12 +501,22 @@ __global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
if
(
LogMode
)
{
LogSoftmaxForwardFunctor
<
AccT
,
T
>
reduction
(
thread_max
,
std
::
log
(
thread_exp
));
ThreadVecWrite
<
LogSoftmaxForwardFunctor
,
T
,
AccT
,
VecSize
>
(
batch_output
,
batch_input
,
dim_size
,
reduction
);
if
(
input_align_shift
==
output_align_shift
)
{
ThreadVecWriteVec
<
LogSoftmaxForwardFunctor
,
T
,
AccT
,
VecSize
>
(
batch_output
,
batch_input
,
dim_size
,
input_align_shift
,
reduction
);
}
else
{
ThreadVecWrite
<
LogSoftmaxForwardFunctor
,
T
,
AccT
,
VecSize
>
(
batch_output
,
batch_input
,
dim_size
,
reduction
);
}
}
else
{
SoftmaxForwardFunctor
<
AccT
,
T
>
reduction
(
thread_max
,
thread_exp
);
ThreadVecWrite
<
SoftmaxForwardFunctor
,
T
,
AccT
,
VecSize
>
(
batch_output
,
batch_input
,
dim_size
,
reduction
);
if
(
input_align_shift
==
output_align_shift
)
{
ThreadVecWriteVec
<
SoftmaxForwardFunctor
,
T
,
AccT
,
VecSize
>
(
batch_output
,
batch_input
,
dim_size
,
input_align_shift
,
reduction
);
}
else
{
ThreadVecWrite
<
SoftmaxForwardFunctor
,
T
,
AccT
,
VecSize
>
(
batch_output
,
batch_input
,
dim_size
,
reduction
);
}
}
}
...
...
@@ -1371,5 +1442,9 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
dev_ctx
,
dx_data
,
dout
.
data
<
T
>
(),
out
.
data
<
T
>
(),
N
,
dim
,
D
);
}
}
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
#undef FIXED_VEC_SIZE_BASE
#undef FIXED_VEC_SIZE
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录