Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
25a0b46d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
25a0b46d
编写于
9月 04, 2023
作者:
D
duanyanhui
提交者:
GitHub
9月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize softmax_mask_fuse (#56877)
上级
d38cd6ce
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
112 addition
and
44 deletion
+112
-44
paddle/fluid/eager/amp_auto_cast.h
paddle/fluid/eager/amp_auto_cast.h
+4
-0
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
+87
-44
test/legacy_test/test_softmax_mask_fuse_op.py
test/legacy_test/test_softmax_mask_fuse_op.py
+21
-0
未找到文件。
paddle/fluid/eager/amp_auto_cast.h
浏览文件 @
25a0b46d
...
...
@@ -75,6 +75,10 @@ inline paddle::Tensor AmpAutoCast(const std::string& input_name,
input_name
!=
"X"
)
{
return
input
;
}
if
(
op_name
==
"fused_softmax_mask"
&&
input_name
==
"Mask"
&&
input
.
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
return
input
;
}
if
(
dst_dtype
==
phi
::
DataType
::
FLOAT16
)
{
if
(
op_name
==
"run_program"
)
{
return
input
;
...
...
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_kernel.cu
浏览文件 @
25a0b46d
...
...
@@ -22,9 +22,9 @@ namespace phi {
namespace
fusion
{
// T == fp16
template
<
typename
T
,
int
pow2_index
>
template
<
typename
T
,
typename
MT
,
int
pow2_index
>
__global__
void
SoftmaxMaskFuseGPUKernel
(
const
T
*
x_data
,
const
T
*
mask_data
,
const
M
T
*
mask_data
,
T
*
y_data
,
int
batch_count
,
int
key_seq_len
)
{
...
...
@@ -62,7 +62,7 @@ __global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
// using float for all inter compute
float
data
[
kLocalBatchSize
][
kLocalIterations
];
T
temp_data
[
kOneLoadingCounts
];
T
temp_mask
[
kOneLoadingCounts
];
M
T
temp_mask
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
...
...
@@ -151,7 +151,6 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
const
DenseTensor
&
mask
,
DenseTensor
*
out
)
{
auto
*
x_data
=
x
.
data
<
T
>
();
auto
*
mask_data
=
mask
.
data
<
T
>
();
auto
*
y_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
x_dim
=
x
.
dims
();
...
...
@@ -226,46 +225,90 @@ void FusedSoftmaxMaskKernel(const Context& dev_ctx,
dim3
blocks
(
query_seq_len
/
batches_per_block
,
attn_heads
,
batches
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// launch the kernel based on the pow2_index
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseGPUKernel
<
T
,
5
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseGPUKernel
<
T
,
6
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseGPUKernel
<
T
,
7
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseGPUKernel
<
T
,
8
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseGPUKernel
<
T
,
9
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseGPUKernel
<
T
,
10
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseGPUKernel
<
T
,
11
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseGPUKernel
<
T
,
12
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseGPUKernel
<
T
,
13
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
default:
break
;
if
(
mask
.
dtype
()
==
x
.
dtype
())
{
auto
*
mask_data
=
mask
.
data
<
T
>
();
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
5
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
6
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
7
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
8
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
9
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
10
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
11
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
12
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseGPUKernel
<
T
,
T
,
13
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
default:
break
;
}
}
else
if
(
mask
.
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
auto
*
mask_data
=
mask
.
data
<
float
>
();
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
5
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
6
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
7
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
8
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
9
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
10
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
11
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
12
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseGPUKernel
<
T
,
float
,
13
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
mask_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
default:
break
;
}
}
}
...
...
test/legacy_test/test_softmax_mask_fuse_op.py
浏览文件 @
25a0b46d
...
...
@@ -78,6 +78,27 @@ class TestSoftmaxMaskFuseOp0(OpTest):
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"X"
],
"Out"
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestSoftmaxMaskFuseOp01
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fused_softmax_mask"
self
.
python_api
=
paddle
.
incubate
.
softmax_mask_fuse
x
=
np
.
random
.
random
((
1
,
1
,
8
,
32
)).
astype
(
"float16"
)
mask
=
np
.
random
.
randint
(
0
,
2
,
(
1
,
1
,
8
,
32
)).
astype
(
"float32"
)
mask_input
=
np
.
where
(
mask
==
1
,
-
10000.0
,
mask
)
self
.
inputs
=
{
'X'
:
x
,
'Mask'
:
mask_input
}
rst
=
_get_softmax
(
x
,
mask_input
)
self
.
outputs
=
{
'Out'
:
rst
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
))
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"X"
],
"Out"
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录