Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
face8f1f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
face8f1f
编写于
9月 21, 2022
作者:
R
RichardWooSJTU
提交者:
GitHub
9月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix multihead_matmul nan error when seq len et 1024 (#46286)
上级
23e06680
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+12
-12
未找到文件。
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
face8f1f
...
@@ -378,7 +378,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
...
@@ -378,7 +378,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
32
==
0
);
T
stride_max
=
-
1e20
f
;
T
stride_max
=
-
1e20
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
stride_max
=
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
stride_max
=
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
>
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
>
stride_max
stride_max
...
@@ -389,13 +389,13 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
...
@@ -389,13 +389,13 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
T
max_val
=
phi
::
funcs
::
blockReduceMax
<
T
>
(
stride_max
,
mask
);
T
max_val
=
phi
::
funcs
::
blockReduceMax
<
T
>
(
stride_max
,
mask
);
T
stride_sum
=
0.
f
;
T
stride_sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
stride_sum
+=
__expf
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
stride_sum
+=
__expf
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
-
max_val
);
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
-
max_val
);
}
}
T
sum_val
=
phi
::
funcs
::
blockReduceSum
<
T
>
(
stride_sum
,
mask
);
T
sum_val
=
phi
::
funcs
::
blockReduceSum
<
T
>
(
stride_sum
,
mask
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
=
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
=
(
T
)(
__expf
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
(
T
)(
__expf
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
-
max_val
)
/
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]
-
max_val
)
/
...
@@ -417,7 +417,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
...
@@ -417,7 +417,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
32
==
0
);
float
stride_max
=
-
1e20
f
;
float
stride_max
=
-
1e20
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float
tmp
=
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
float
tmp
=
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_max
=
tmp
>
stride_max
?
tmp
:
stride_max
;
stride_max
=
tmp
>
stride_max
?
tmp
:
stride_max
;
...
@@ -425,14 +425,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
...
@@ -425,14 +425,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
float
max_val
=
phi
::
funcs
::
blockReduceMax
<
float
>
(
stride_max
,
mask
);
float
max_val
=
phi
::
funcs
::
blockReduceMax
<
float
>
(
stride_max
,
mask
);
float
stride_sum
=
0.
f
;
float
stride_sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float
tmp
=
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
float
tmp
=
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_sum
+=
__expf
(
tmp
-
max_val
);
stride_sum
+=
__expf
(
tmp
-
max_val
);
}
}
float
sum_val
=
phi
::
funcs
::
blockReduceSum
<
float
>
(
stride_sum
,
mask
);
float
sum_val
=
phi
::
funcs
::
blockReduceSum
<
float
>
(
stride_sum
,
mask
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float
tmp
=
float
tmp
=
__expf
(
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
__expf
(
static_cast
<
float
>
(
qk_buf
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
])
-
bias_qk
[
threadIdx
.
x
+
i
+
qk_offset
])
-
...
@@ -454,7 +454,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
...
@@ -454,7 +454,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
32
==
0
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
float2
cur
=
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_max
.
x
=
max
(
stride_max
.
x
,
cur
.
x
);
stride_max
.
x
=
max
(
stride_max
.
x
,
cur
.
x
);
...
@@ -464,7 +464,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
...
@@ -464,7 +464,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
phi
::
funcs
::
blockReduceMax
<
float
>
(
max
(
stride_max
.
x
,
stride_max
.
y
),
mask
);
phi
::
funcs
::
blockReduceMax
<
float
>
(
max
(
stride_max
.
x
,
stride_max
.
y
),
mask
);
float2
stride_sum
=
make_float2
(
0.
f
,
0.
f
);
float2
stride_sum
=
make_float2
(
0.
f
,
0.
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
float2
cur
=
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
stride_sum
.
x
+=
__expf
(
cur
.
x
-
max_val
);
stride_sum
.
x
+=
__expf
(
cur
.
x
-
max_val
);
...
@@ -475,7 +475,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
...
@@ -475,7 +475,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
phi
::
funcs
::
blockReduceSum
<
float
>
(
stride_sum
.
x
+
stride_sum
.
y
,
mask
)
+
phi
::
funcs
::
blockReduceSum
<
float
>
(
stride_sum
.
x
+
stride_sum
.
y
,
mask
)
+
1e-6
f
;
1e-6
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
float2
cur
=
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
=
phi
::
funcs
::
FloatsToPair
<
T
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
=
phi
::
funcs
::
FloatsToPair
<
T
>
(
...
@@ -499,7 +499,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
...
@@ -499,7 +499,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
32
==
0
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
float2
cur
=
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
...
@@ -510,7 +510,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
...
@@ -510,7 +510,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
phi
::
funcs
::
blockReduceMax
<
float
>
(
max
(
stride_max
.
x
,
stride_max
.
y
),
mask
);
phi
::
funcs
::
blockReduceMax
<
float
>
(
max
(
stride_max
.
x
,
stride_max
.
y
),
mask
);
float2
stride_sum
=
make_float2
(
0.
f
,
0.
f
);
float2
stride_sum
=
make_float2
(
0.
f
,
0.
f
);
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
float2
cur
=
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
...
@@ -522,7 +522,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
...
@@ -522,7 +522,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
phi
::
funcs
::
blockReduceSum
<
float
>
(
stride_sum
.
x
+
stride_sum
.
y
,
mask
)
+
phi
::
funcs
::
blockReduceSum
<
float
>
(
stride_sum
.
x
+
stride_sum
.
y
,
mask
)
+
1e-6
f
;
1e-6
f
;
for
(
int
i
=
0
;
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
float2
cur
=
float2
cur
=
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
threadIdx
.
x
+
i
+
qk_offset
]
+
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
bias_qk_
[
threadIdx
.
x
+
i
+
qk_offset
]);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录