Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
afb13484
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
afb13484
编写于
11月 26, 2019
作者:
Z
zhaoyuchen2018
提交者:
GitHub
11月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix ernie python infer diff (#21311)
* Fix ernie python infer diff * Refine mask test=develop
上级
b6ce4f8b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 additions
and
16 deletions
+24
-16
paddle/fluid/operators/fused/multihead_matmul_op.cu
paddle/fluid/operators/fused/multihead_matmul_op.cu
+24
-16
未找到文件。
paddle/fluid/operators/fused/multihead_matmul_op.cu
浏览文件 @
afb13484
...
...
@@ -28,10 +28,10 @@ namespace operators {
#define WARP_SIZE 32
template
<
typename
T
>
__inline__
__device__
T
warpReduceSum
(
T
val
)
{
__inline__
__device__
T
warpReduceSum
(
T
val
,
unsigned
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val
+=
__shfl_xor_sync
(
FINAL_MASK
,
val
,
mask
,
warpSize
);
val
+=
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
);
#else
val
+=
__shfl_xor
(
val
,
mask
,
warpSize
);
#endif
...
...
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
/* Calculate the sum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
blockReduceSum
(
T
val
)
{
__inline__
__device__
T
blockReduceSum
(
T
val
,
unsigned
mask
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
val
=
warpReduceSum
<
T
>
(
val
);
val
=
warpReduceSum
<
T
>
(
val
,
mask
);
if
(
lane
==
0
)
shared
[
wid
]
=
val
;
__syncthreads
();
val
=
(
threadIdx
.
x
<
(
blockDim
.
x
>>
5
))
?
shared
[
lane
]
:
(
T
)(
0.0
f
);
val
=
warpReduceSum
<
T
>
(
val
);
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
val
=
(
threadIdx
.
x
<
block_span
)
?
shared
[
lane
]
:
(
T
)(
0.0
f
);
val
=
warpReduceSum
<
T
>
(
val
,
mask
);
return
val
;
}
template
<
typename
T
>
__inline__
__device__
T
warpReduceMax
(
T
val
)
{
__inline__
__device__
T
warpReduceMax
(
T
val
,
unsigned
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val
=
max
(
val
,
__shfl_xor_sync
(
FINAL_MASK
,
val
,
mask
,
warpSize
));
val
=
max
(
val
,
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
));
#else
val
=
max
(
val
,
__shfl_xor
(
val
,
mask
,
warpSize
));
#endif
...
...
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
/* Calculate the maximum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
blockReduceMax
(
T
val
)
{
__inline__
__device__
T
blockReduceMax
(
T
val
,
unsigned
mask
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
val
=
warpReduceMax
(
val
);
val
=
warpReduceMax
(
val
,
mask
);
if
(
lane
==
0
)
shared
[
wid
]
=
val
;
__syncthreads
();
val
=
(
threadIdx
.
x
<
(
blockDim
.
x
>>
5
))
?
shared
[
lane
]
:
-
1e10
f
;
val
=
warpReduceMax
(
val
);
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
val
=
(
threadIdx
.
x
<
block_span
)
?
shared
[
lane
]
:
-
1e10
f
;
val
=
warpReduceMax
(
val
,
mask
);
return
val
;
}
...
...
@@ -190,7 +194,8 @@ template <typename T>
__global__
void
softmax_kernel_with_eltadd
(
T
*
qk_buf_
,
const
T
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
)
{
const
int
seq_len
,
const
unsigned
mask
)
{
int
seq_id
=
blockIdx
.
x
%
seq_len
;
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
int
bias_offset
=
blockIdx
.
x
%
(
head_num
*
seq_len
)
*
seq_len
;
...
...
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
bias_qk_
[
threadIdx
.
x
+
bias_offset
]))
:
0.0
f
;
float
tmp
=
threadIdx
.
x
<
seq_len
?
static_cast
<
float
>
(
qk
)
:
-
1e20
f
;
float
max_val
=
blockReduceMax
<
float
>
(
tmp
);
float
max_val
=
blockReduceMax
<
float
>
(
tmp
,
mask
);
if
(
threadIdx
.
x
==
0
)
s_max
=
max_val
;
__syncthreads
();
float
qk_tmp
=
threadIdx
.
x
<
seq_len
?
__expf
(
static_cast
<
float
>
(
tmp
-
s_max
))
:
0.0
f
;
float
sum_val
=
blockReduceSum
<
float
>
(
qk_tmp
);
float
sum_val
=
blockReduceSum
<
float
>
(
qk_tmp
,
mask
);
if
(
threadIdx
.
x
==
0
)
{
s_sum
=
sum_val
+
1e-6
f
;
...
...
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
int
grid
=
m
;
int
block
=
k
;
unsigned
mask
=
block
<
32
?
(((
unsigned
)
1
<<
block
)
-
1
)
:
FINAL_MASK
;
softmax_kernel_with_eltadd
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
qk_buf_
,
bias_qk
,
batch_size
,
head_num
,
seq_len
);
qk_buf_
,
bias_qk
,
batch_size
,
head_num
,
seq_len
,
mask
);
}
template
<
typename
T
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录