BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit a5a8d144
Unverified commit a5a8d144
Authored Dec 16, 2019 by zhaoyuchen2018; committed by GitHub on Dec 16, 2019
Fix softmax cuda bug (#21720)
* Fix softmax cuda bug
* Refine multihead log and softmax logic
Parent: 943a4449
Showing 3 changed files with 29 additions and 9 deletions (+29, -9)
paddle/fluid/operators/fused/multihead_matmul_op.cc (+26, -5)
paddle/fluid/operators/fused/multihead_matmul_op.cu (+1, -3)
python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py (+2, -1)
paddle/fluid/operators/fused/multihead_matmul_op.cc
@@ -84,15 +84,36 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
    PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
                      "Multihead input bias should have same batch size");
    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
                      "Multihead input bias should have same size");
    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
                      "Multihead input bias should have same size");

    auto dim_bias_qk = context->GetInputDim("BiasQK");
    PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
                      "Multihead input bias qk should be at least 4-D tensor.");

    int b_size = dim_bias_q.size() - 1;
    int size = dim_q.size() - 1;

    PADDLE_ENFORCE_EQ(dim_bias_q[b_size], dim_q[size],
                      platform::errors::InvalidArgument(
                          "bias_q's last dim size should equal to"
                          " q last dim size, but bias_q's size is:%d q is:%d",
                          dim_bias_q[b_size], dim_q[size]));
    PADDLE_ENFORCE_EQ(dim_bias_k[b_size], dim_k[size],
                      platform::errors::InvalidArgument(
                          "bias_k's last dim size should equal to"
                          " k last dim size, but bias_k's size is:%d k is:%d",
                          dim_bias_k[b_size], dim_k[size]));
    PADDLE_ENFORCE_EQ(dim_bias_v[b_size], dim_v[size],
                      platform::errors::InvalidArgument(
                          "bias_v's last dim size should equal to"
                          " v last dim size, but bias_v's size is:%d v is:%d",
                          dim_bias_v[b_size], dim_v[size]));

    PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
                      platform::errors::InvalidArgument(
                          "q should have same batch size"
                          "with bias_qk, but q's batch size:%d not equal to "
                          "bias_qk's batch size:%d",
                          dim_q[0], dim_bias_qk[0]));

    int head_number = context->Attrs().Get<int>("head_number");
    PADDLE_ENFORCE_GT(head_number, 1,
                      "Multihead input head number should be at least 1.");
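The checks in this hunk tie the bias shapes to their inputs: BiasQ, BiasK, and BiasV must agree with each other and each must match the last dimension of Q, K, and V, while BiasQK must be at least 4-D and share Q's leading batch dimension. Below is a minimal standalone sketch of those shape relations, restated as plain asserts over made-up example dimensions (batch 1, seq_len 2, head_number 2, hidden size 8; none of these values come from the commit). It is an illustration only, not Paddle code.

// Illustration only: the shape relations enforced by the InferShape checks above,
// restated as plain asserts over hypothetical example dims.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical dims: batch = 1, seq_len = 2, head_number = 2, hidden = 8.
  std::vector<int64_t> dim_q{1, 2, 8}, dim_k{1, 2, 8}, dim_v{1, 2, 8};
  std::vector<int64_t> dim_bias_q{1, 8}, dim_bias_k{1, 8}, dim_bias_v{1, 8};
  std::vector<int64_t> dim_bias_qk{1, 2, 2, 2};  // [batch, head, seq, seq]

  // The Q/K/V biases agree with each other ...
  assert(dim_bias_q[0] == dim_bias_v[0]);
  assert(dim_bias_q[1] == dim_bias_k[1] && dim_bias_q[1] == dim_bias_v[1]);
  // ... and each bias's last dim matches its input's last dim.
  assert(dim_bias_q.back() == dim_q.back());
  assert(dim_bias_k.back() == dim_k.back());
  assert(dim_bias_v.back() == dim_v.back());
  // BiasQK is at least 4-D and shares Q's batch size.
  assert(dim_bias_qk.size() > 3);
  assert(dim_q[0] == dim_bias_qk[0]);
  return 0;
}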
paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -196,15 +196,13 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int head_num,
                                            const int seq_len,
                                            const unsigned mask) {
-  int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
-  int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
   __shared__ float s_sum, s_max;
   float qk = threadIdx.x < seq_len
                  ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
-                                       bias_qk_[threadIdx.x + bias_offset]))
+                                       bias_qk_[threadIdx.x + qk_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
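The kernel change drops seq_id and the separate bias_offset. bias_offset wrapped the bias index modulo head_num * seq_len, which effectively reused a single batch's BiasQK slice for every batch; after the fix, bias_qk_ is indexed with the same qk_offset as the scores, which assumes BiasQK is laid out per batch as [batch, head_num, seq_len, seq_len], matching the new shape check in multihead_matmul_op.cc and the updated test below. The following is a minimal self-contained CUDA sketch of that row indexing, not Paddle's kernel: the softmax reduction is done serially by thread 0 for clarity, and the kernel name, sizes, and launch configuration are made up for illustration.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// One block per (batch, head, query-position) row; qk and bias_qk share the
// flat offset blockIdx.x * seq_len because both use a [batch, head, seq, seq] layout.
__global__ void biased_softmax_rows(float *qk, const float *bias_qk, int seq_len) {
  extern __shared__ float row[];       // one row of biased scores
  int offset = blockIdx.x * seq_len;   // same offset for qk and bias_qk
  int tid = threadIdx.x;
  if (tid < seq_len) row[tid] = qk[offset + tid] + bias_qk[offset + tid];
  __syncthreads();
  if (tid == 0) {                      // serial max/exp/sum: clarity over speed
    float m = row[0];
    for (int i = 1; i < seq_len; ++i) m = fmaxf(m, row[i]);
    float s = 0.f;
    for (int i = 0; i < seq_len; ++i) { row[i] = expf(row[i] - m); s += row[i]; }
    for (int i = 0; i < seq_len; ++i) qk[offset + i] = row[i] / s;
  }
}

int main() {
  const int batch = 2, head = 2, seq = 4;   // made-up sizes
  const int n = batch * head * seq * seq;
  std::vector<float> h_qk(n, 1.0f), h_bias(n, 0.0f);
  float *d_qk, *d_bias;
  cudaMalloc(&d_qk, n * sizeof(float));
  cudaMalloc(&d_bias, n * sizeof(float));
  cudaMemcpy(d_qk, h_qk.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bias, h_bias.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  // grid = batch * head * seq rows, block = seq threads, seq floats of shared memory.
  biased_softmax_rows<<<batch * head * seq, seq, seq * sizeof(float)>>>(d_qk, d_bias, seq);
  cudaMemcpy(h_qk.data(), d_qk, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0: %.3f %.3f %.3f %.3f\n", h_qk[0], h_qk[1], h_qk[2], h_qk[3]);  // 0.250 each
  cudaFree(d_qk);
  cudaFree(d_bias);
  return 0;
}

Because every (batch, head, query-position) row gets its own block and its own bias row, no modulo wrapping over the batch dimension is needed, which is the point of replacing bias_offset with qk_offset above.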
python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
@@ -54,7 +54,8 @@ class TestFusedMultiheadMatmulOp(OpTest):
         self.BiasK = np.random.random((1, w)).astype("float32")
         self.BiasV = np.random.random((1, w)).astype("float32")
         self.BiasQK = np.random.random(
-            (1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
+            (self.batch_size, self.head_number, self.seq_len,
+             self.seq_len)).astype("float32")
         # Compute Q path
         fc_q = self.Q + self.BiasQ
         reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,