Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
42f35841
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
42f35841
编写于
11月 18, 2022
作者:
F
feng_shuai
提交者:
GitHub
11月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: supoort huge length of attention (#48053)
上级
85598e31
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
44 addition
and
21 deletion
+44
-21
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+44
-21
未找到文件。
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
42f35841
...
...
@@ -783,6 +783,19 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
}
}
#define SOFTMAX_KERNEL_WITH_MASK(REPEAT_THREAD) \
do { \
block.x /= REPEAT_THREAD; \
grid.x /= 4; \
constexpr int NUM = 4; \
softmax_kernel_with_mask<half, REPEAT_THREAD, NUM> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_), \
(const half *)bias_qk, \
batch_size, \
head_num, \
seq_len); \
} while (0)
template
<
typename
T
>
inline
void
MatMulWithHeadQK
(
const
phi
::
GPUContext
&
context
,
int
head_num
,
...
...
@@ -843,22 +856,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"
));
#else
constexpr
int
ITEMS_PER_THREAD
=
1
;
bool
is_half2
=
true
;
dim3
grid
(
seq_len
,
batch_size
,
head_num
);
dim3
block
((
seq_len
/
2
+
31
)
/
32
*
32
);
block
.
x
/=
ITEMS_PER_THREAD
;
assert
(
block
.
x
<=
1024
);
assert
(
grid
.
x
%
4
==
0
);
grid
.
x
/=
4
;
constexpr
int
NUM
=
4
;
softmax_kernel_with_mask
<
half
,
ITEMS_PER_THREAD
,
NUM
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
half
*>
(
qk_buf_
),
(
const
half
*
)
bias_qk
,
batch_size
,
head_num
,
seq_len
);
SOFTMAX_KERNEL_WITH_MASK
(
1
);
#endif
}
else
{
SoftmaxKernelWithEltadd2
<
__half2
><<<
grid
,
block
,
0
,
stream
>>>
(
...
...
@@ -887,6 +887,28 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
head_num
,
seq_len
/
2
,
FINAL_MASK
);
}
else
{
if
(
bias_is_mask
)
{
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
PADDLE_ENFORCE_EQ
(
bias_is_mask
,
false
,
platform
::
errors
::
InvalidArgument
(
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"
));
#else
dim3
grid
(
seq_len
,
batch_size
,
head_num
);
dim3
block
((
seq_len
/
2
+
31
)
/
32
*
32
);
if
(
block
.
x
>
0
&&
block
.
x
<=
1024
)
{
SOFTMAX_KERNEL_WITH_MASK
(
1
);
}
else
if
(
block
.
x
<=
2048
)
{
SOFTMAX_KERNEL_WITH_MASK
(
2
);
}
else
if
(
block
.
x
<=
4096
)
{
SOFTMAX_KERNEL_WITH_MASK
(
4
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Cannot support the length of attention > 8192."
));
}
#endif
}
else
{
SoftmaxKernelWithEltaddForLarge2
<
__half2
><<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
__half2
*>
(
qk_buf_
),
...
...
@@ -896,6 +918,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
seq_len
/
2
,
FINAL_MASK
);
}
}
}
else
{
SoftmaxKernelWithEltaddForLarge
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
qk_buf_
,
bias_qk
,
batch_size
,
head_num
,
seq_len
,
FINAL_MASK
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录