Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d7f4599d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d7f4599d
编写于
7月 06, 2022
作者:
L
LiYuRio
提交者:
GitHub
7月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix nan in fused multi transformer (#44093)
上级
54a9daf2
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
64 addition
and
7 deletion
+64
-7
paddle/fluid/distributed/store/tcp_store.cc
paddle/fluid/distributed/store/tcp_store.cc
+4
-1
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+60
-6
未找到文件。
paddle/fluid/distributed/store/tcp_store.cc
浏览文件 @
d7f4599d
...
...
@@ -125,7 +125,10 @@ void MasterDaemon::CloseControlFd() {
void
MasterDaemon
::
StopByControlFd
()
{
VLOG
(
4
)
<<
(
"begin to run StopByControlFd"
);
if
(
_control_fd
[
1
]
!=
-
1
)
{
::
write
(
_control_fd
[
1
],
"
\0
"
,
1
);
PADDLE_ENFORCE_NE
(
::
write
(
_control_fd
[
1
],
"
\0
"
,
1
),
-
1
,
platform
::
errors
::
Fatal
(
"failed to write control pipe errno:%d"
,
errno
));
// close the write end of the pipe
::
close
(
_control_fd
[
1
]);
_control_fd
[
1
]
=
-
1
;
...
...
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
d7f4599d
...
...
@@ -294,6 +294,52 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return
c
;
}
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
float
b
)
{
float2
tmp
=
half2_to_float2
(
a
);
float2
tmp_res
;
tmp_res
.
x
=
tmp
.
x
*
b
;
tmp_res
.
y
=
tmp
.
y
*
b
;
uint32_t
res
=
float2_to_half2
(
tmp_res
);
return
res
;
}
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
float
b
)
{
uint2
res
;
res
.
x
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
x
,
b
);
res
.
y
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
y
,
b
);
return
res
;
}
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
float
b
)
{
uint4
res
;
res
.
x
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
x
,
b
);
res
.
y
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
y
,
b
);
res
.
z
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
z
,
b
);
res
.
w
=
mul
<
uint32_t
,
uint32_t
,
float
>
(
a
.
w
,
b
);
return
res
;
}
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float
b
)
{
float2
res
;
res
.
x
=
a
.
x
*
b
;
res
.
y
=
a
.
y
*
b
;
return
res
;
}
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float
b
)
{
float4
res
;
res
.
x
=
a
.
x
*
b
;
res
.
y
=
a
.
y
*
b
;
res
.
z
=
a
.
z
*
b
;
res
.
w
=
a
.
w
*
b
;
return
res
;
}
inline
__device__
float
sum
(
float
v
)
{
return
v
;
}
inline
__device__
float
sum
(
float2
v
)
{
return
v
.
x
+
v
.
y
;
}
inline
__device__
float
sum
(
float4
v
)
{
return
v
.
x
+
v
.
y
+
v
.
z
+
v
.
w
;
}
...
...
@@ -445,11 +491,15 @@ inline __device__ Float8_ cast_to_float(uint4 u) {
}
template
<
int
THREADS_PER_KEY
,
typename
K_vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
])
{
K_vec
qk_vec
=
mul
<
K_vec
,
K_vec
,
K_vec
>
(
q
[
0
],
k
[
0
]);
inline
__device__
float
qk_dot_
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
],
float
inv_sqrt_dh
)
{
K_vec
inv_q
=
mul
<
K_vec
,
K_vec
,
float
>
(
q
[
0
],
inv_sqrt_dh
);
K_vec
qk_vec
=
mul
<
K_vec
,
K_vec
,
K_vec
>
(
inv_q
,
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
inv_q
=
mul
<
K_vec
,
K_vec
,
float
>
(
q
[
ii
],
inv_sqrt_dh
);
qk_vec
=
fma
(
inv_q
,
k
[
ii
],
qk_vec
);
}
float
qk
=
sum
(
qk_vec
);
...
...
@@ -463,8 +513,10 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
Qk_dot
{
template
<
typename
K_vec
,
int
N
>
static
inline
__device__
float
dot
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREADS_PER_KEY
>
(
q
,
k
);
static
inline
__device__
float
dot
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
],
float
inv_sqrt_dh
)
{
return
qk_dot_
<
THREADS_PER_KEY
>
(
q
,
k
,
inv_sqrt_dh
);
}
};
...
...
@@ -706,7 +758,9 @@ __global__ void masked_multihead_attention_kernel(
}
}
float
qk
=
Qk_dot
<
T
,
THREADS_PER_KEY
>::
dot
(
q
,
k
)
*
params
.
inv_sqrt_dh
;
// NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float
qk
=
Qk_dot
<
T
,
THREADS_PER_KEY
>::
dot
(
q
,
k
,
params
.
inv_sqrt_dh
);
// bool is_mask = false;
if
(
ti
<
params
.
timestep
&&
tid
%
THREADS_PER_KEY
==
0
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录