Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f9d5ae4e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f9d5ae4e
编写于
5月 16, 2022
作者:
W
WangXi
提交者:
GitHub
5月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fused_multi_transformer add fused softmax mask (#42636)
上级
661d0800
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
235 addition
and
20 deletion
+235
-20
paddle/fluid/operators/fused/fmha_ref.h
paddle/fluid/operators/fused/fmha_ref.h
+18
-11
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+6
-8
paddle/fluid/operators/fused/fused_softmax_mask.cu.h
paddle/fluid/operators/fused/fused_softmax_mask.cu.h
+204
-0
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
.../fluid/tests/unittests/test_fused_multi_transformer_op.py
+7
-1
未找到文件。
paddle/fluid/operators/fused/fmha_ref.h
浏览文件 @
f9d5ae4e
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
...
...
@@ -148,18 +149,24 @@ class FMHARef {
stride_b
);
int
softmax_axis
=
-
1
;
if
(
src_mask_tensor
!=
nullptr
)
{
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
Tensor
*>
outs
;
ins
.
emplace_back
(
qk_out_tensor
);
ins
.
emplace_back
(
src_mask_tensor
);
outs
.
emplace_back
(
src_mask_out_tensor
);
int
elewise_add_axis
=
-
1
;
paddle
::
operators
::
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx_
,
ins
,
&
outs
,
elewise_add_axis
,
AddFunctor
<
T
>
());
if
(
src_mask_out_tensor
==
nullptr
&&
seq_len_
==
out_seq_len
)
{
LaunchFusedSoftmaxMaskKernel
<
T
>
(
qk_out_data
,
src_mask_tensor
->
data
<
T
>
(),
softmax_out_data
,
batch_size_
,
num_head_
,
seq_len_
,
dev_ctx_
.
stream
());
}
else
{
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
Tensor
*>
outs
;
ins
.
emplace_back
(
qk_out_tensor
);
ins
.
emplace_back
(
src_mask_tensor
);
outs
.
emplace_back
(
src_mask_out_tensor
);
int
elewise_add_axis
=
-
1
;
paddle
::
operators
::
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx_
,
ins
,
&
outs
,
elewise_add_axis
,
AddFunctor
<
T
>
());
phi
::
SoftmaxForwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
*
src_mask_out_tensor
,
softmax_axis
,
softmax_out_tensor
);
phi
::
SoftmaxForwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
*
src_mask_out_tensor
,
softmax_axis
,
softmax_out_tensor
);
}
}
else
{
phi
::
SoftmaxForwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
*
qk_out_tensor
,
softmax_axis
,
softmax_out_tensor
);
...
...
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
f9d5ae4e
...
...
@@ -1084,11 +1084,9 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
*
qk_out_data
=
qk_out
.
mutable_data
<
T
>
({
bsz
,
num_head
,
seq_len
,
out_seq_len
},
place
);
Tensor
s
rc_mask_out
,
s
oftmax_out
;
Tensor
softmax_out
;
Tensor
attn_dropout_mask_out
,
attn_dropout_out
;
Tensor
qktv_out
,
fmha_out
;
auto
*
src_mask_out_data
=
src_mask_out
.
mutable_data
<
T
>
(
{
bsz
,
num_head
,
seq_len
,
out_seq_len
},
place
);
auto
*
softmax_out_data
=
softmax_out
.
mutable_data
<
T
>
(
{
bsz
,
num_head
,
seq_len
,
out_seq_len
},
place
);
...
...
@@ -1219,10 +1217,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
1.
/
sqrt
(
dim_head
));
}
else
if
(
cache_kv_out
)
{
// generation context stage
// TODO(wangxi): can remove dropout in inference
fmha_compute
.
ComputeForward
(
qkv_out
,
nullptr
,
src_mask
,
&
transpose_out_2
,
nullptr
,
&
qk_out
,
&
src_mask_out
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
fmha_compute
.
ComputeForward
(
qkv_out
,
nullptr
,
src_mask
,
&
transpose_out_2
,
nullptr
,
&
qk_out
,
nullptr
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
// [3, bsz, num_head, seq_len, head_dim]
T
*
qkv_data
=
transpose_out_2_data
;
int64_t
q_size
=
bsz
*
seq_len
*
num_head
*
dim_head
;
...
...
@@ -1245,7 +1243,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// TODO(wangxi): can remove dropout in inference
fmha_compute
.
ComputeForward
(
qkv_out
,
cache_kv
,
src_mask
,
&
transpose_out_2
,
cache_kv_out
,
&
qk_out
,
&
src_mask_out
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
qk_out
,
nullptr
,
&
softmax_out
,
&
attn_dropout_mask_out
,
&
attn_dropout_out
,
&
qktv_out
,
&
fmha_out
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
...
...
paddle/fluid/operators/fused/fused_softmax_mask.cu.h
0 → 100644
浏览文件 @
f9d5ae4e
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdint>
#include <limits>

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
namespace
plat
=
paddle
::
platform
;
// Lane mask with all 32 bits set; passed as the first argument of the
// __shfl_xor_sync calls in the warp reductions below.
#define FINAL_MASK 0xffffffff
// Integer division of x by y rounded up (intended for non-negative x and
// positive y). Arguments are evaluated more than once.
#define DIV_UP(x, y) (((x) + (y)-1) / (y))
// Sums `val` over the 32 lanes of a warp with an XOR-shuffle butterfly.
// The lane distance is halved each step (16, 8, 4, 2, 1); after the last
// step the returned value is the sum of the inputs supplied by all lanes.
// FINAL_MASK is used as the shuffle mask and 32 as the shuffle width.
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int lane_xor = 16; lane_xor >= 1; lane_xor /= 2) {
    val += __shfl_xor_sync(FINAL_MASK, val, lane_xor, 32);
  }
  return val;
}
// Takes the maximum of `val` over the 32 lanes of a warp with an XOR-shuffle
// butterfly. The lane distance is halved each step (16, 8, 4, 2, 1); after
// the last step the returned value is the maximum of the inputs supplied by
// all lanes. FINAL_MASK is used as the shuffle mask and 32 as the width.
template <typename T>
__inline__ __device__ T warpReduceMax(T val) {
#pragma unroll
  for (int lane_xor = 16; lane_xor >= 1; lane_xor /= 2) {
    val = max(val, __shfl_xor_sync(FINAL_MASK, val, lane_xor, 32));
  }
  return val;
}
// Returns the smallest power of two `e` such that e * 32 >= seq_len, i.e. how
// many elements each of 32 lanes must hold to cover a row of seq_len values.
// Returns 1 for every seq_len <= 32, including zero and negative values.
inline int ElementsCeil(int seq_len) {
  constexpr int kLanes = 32;
  int elements = 1;
  // The product is formed in 64-bit: in plain int, `elements * 32` overflows
  // (undefined behaviour) once seq_len exceeds 2^30 and the loop may never
  // terminate. `elements` itself tops out at 2^26, so doubling it is safe.
  while (static_cast<int64_t>(elements) * kLanes < seq_len) {
    elements *= 2;
  }
  return elements;
}
// Fused "add mask, then softmax" along the last dimension.
//
// Indexing below assumes this launch layout:
//   blockDim/threadIdx = (warp_size, warps_per_block)
//   gridDim/blockIdx   = (DIV_UP(seq_len, warps_per_block), batch_size,
//                         head_num)
// so every warp (one threadIdx.y) owns one row of seq_len values and a block
// covers warps_per_block (4) consecutive rows.
//
// Memory layout implied by the offsets:
//   src / dst : [batch, head, seq_len, seq_len], contiguous
//   mask      : [batch, seq_len, seq_len], the same mask row for every head
//
// Template parameters:
//   T                    element type of src / mask / dst
//   VEC_SIZE             number of T values moved by one phi::Load/Store
//   ELEMENTS_PER_THREADS capacity (in T values) of the per-thread buffer
//
// The loops rely on ELEMENTS_PER_THREADS * 32 >= seq_len (otherwise
// `elements` is indexed out of range) and on seq_len % VEC_SIZE == 0
// (otherwise the last vector access runs past the row).
//
// For each row: dst = exp(x - max(x)) / (sum(exp(x - max(x))) + 1e-6) with
// x = src + mask; max/sum are accumulated in float.
template <typename T, int VEC_SIZE, int ELEMENTS_PER_THREADS>
__global__ void FusedSoftmaxMaskVecKernel(T* dst, const T* src, const T* mask,
                                          int seq_len) {
  constexpr int block_size = 128;
  constexpr int warp_size = 32;
  constexpr int warps_per_block = block_size / warp_size;

  // Row handled by this warp: blockIdx.x * 4 + warp_id,
  // e.g. seq_len=128: row 127 = 31 * 4 + 3.
  int seq_id = blockIdx.x * warps_per_block + threadIdx.y;
  // All lanes of a warp share threadIdx.y, so a warp leaves as a whole.
  if (seq_id >= seq_len) return;

  // Offsets are built in 64-bit. In 32-bit arithmetic
  // batch * head_num * seq_len * seq_len wraps once it reaches 2^31
  // (e.g. 32 * 16 * 2048 * 2048), which would move the pointers to a wrong
  // or negative position.
  const int64_t batch_id = blockIdx.y;
  const int64_t head_id = blockIdx.z;
  const int64_t head_num = gridDim.z;
  // ((bid * head_num + hid) * seq_len + seq_id) * seq_len
  const int64_t offset =
      ((batch_id * head_num + head_id) * seq_len + seq_id) * seq_len;
  // (bid * seq_len + seq_id) * seq_len
  const int64_t mask_offset = (batch_id * seq_len + seq_id) * seq_len;

  src += offset;
  dst += offset;
  mask += mask_offset;

  static_assert(ELEMENTS_PER_THREADS % VEC_SIZE == 0,
                "ELEMENTS_PER_THREADS must be a multiple of VEC_SIZE");
  constexpr int VEC_NUMS = ELEMENTS_PER_THREADS / VEC_SIZE;
  using VecT = phi::AlignedVector<T, VEC_SIZE>;

  VecT elements[VEC_NUMS];
  VecT tmp_mask;
  float max_val = -std::numeric_limits<float>::infinity();

  // Pass 1: load src and mask, add them, track the per-lane maximum.
  // Lane `threadIdx.x` visits vectors threadIdx.x, threadIdx.x + 32, ...
  for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
    phi::Load(src + (i * warp_size + threadIdx.x) * VEC_SIZE, &elements[i]);
    phi::Load(mask + (i * warp_size + threadIdx.x) * VEC_SIZE, &tmp_mask);
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      // TODO(wangxi): vec add
      elements[i][j] += tmp_mask[j];
      max_val = max(max_val, static_cast<float>(elements[i][j]));
    }
  }
  max_val = warpReduceMax(max_val);

  // Pass 2: exponentiate relative to the row maximum and accumulate the sum.
  float sum_val = 0;
  for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      float tmp = __expf(static_cast<float>(elements[i][j]) - max_val);
      sum_val += tmp;
      elements[i][j] = static_cast<T>(tmp);
    }
  }
  sum_val = warpReduceSum(sum_val);
  // Reciprocal of the row sum. NOTE(review): the 1e-6f term presumably guards
  // against a zero divisor -- confirm the intent with the author.
  float inv_sum = __fdividef(1.0f, sum_val + 1e-6f);

  // Pass 3: normalise and write the row back.
  for (int i = 0; (i * warp_size + threadIdx.x) * VEC_SIZE < seq_len; ++i) {
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      float tmp = static_cast<float>(elements[i][j]) * inv_sum;
      elements[i][j] = static_cast<T>(tmp);
    }
    phi::Store(elements[i], dst + (i * warp_size + threadIdx.x) * VEC_SIZE);
  }
}
// Launches FusedSoftmaxMaskVecKernel<T, VEC_SIZE, ELEMENTS>. The expansion
// refers to `T`, `grid`, `block`, `stream`, `dst`, `src`, `mask` and
// `seq_len`, so all of them must be visible at the point of use.
#define SOFTMAX_MASK_KERNEL(VEC_SIZE, ELEMENTS) \
FusedSoftmaxMaskVecKernel<T, VEC_SIZE, \
ELEMENTS><<<grid, block, 0, stream>>>( \
dst, src, mask, seq_len)
// FIXME(wangxi): It is found that the performance of VEC_SIZE=2 is better
// than that of =4 and =8. Further analysis of the kernel is needed later.
// #define SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS) \
// do { \
// if (sizeof(T) == 2 && seq_len % 8 == 0) { \
// FusedSoftmaxMaskVecKernel<plat::float16, 8, ELEMENTS> \
// <<<grid, block, 0, stream>>>( \
// (plat::float16*)dst, (const plat::float16*)src, mask, seq_len); \
// } \
// else if (seq_len % 4 == 0) SOFTMAX_MASK_KERNEL(4, ELEMENTS); \
// else if (seq_len % 2 == 0) SOFTMAX_MASK_KERNEL(2, ELEMENTS); \
// else SOFTMAX_MASK_KERNEL(1, ELEMENTS); \
// } while(0)
// Chooses the vector width for a given ELEMENTS value: VEC_SIZE=2 when
// seq_len is even, VEC_SIZE=1 otherwise, then launches the kernel.
#define SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS) \
do { \
if (seq_len % 2 == 0) { \
SOFTMAX_MASK_KERNEL(2, ELEMENTS); \
} else { \
SOFTMAX_MASK_KERNEL(1, ELEMENTS); \
} \
} while (0)
// Expands to a complete `case ELEMENTS: { ...; break; }` arm for use inside a
// switch statement; the arm runs SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS).
#define CASE_SOFTMAX_MASK_KERNEL(ELEMENTS) \
case ELEMENTS: { \
SELECT_SOFTMAX_MASK_KERNEL(ELEMENTS); \
break; \
}
// template <typename T, typename MaskT = T>
// Host-side launcher for FusedSoftmaxMaskVecKernel.
//
//   src, mask  inputs that are added together before the softmax
//   dst        output buffer
//   batch_size becomes gridDim.y
//   head_num   becomes gridDim.z
//   seq_len    row length; must satisfy 0 < seq_len <= 4096, otherwise an
//              InvalidArgument error is raised through PADDLE_ENFORCE_EQ
//   stream     CUDA stream the kernel is enqueued on
//
// The kernel instantiation is picked from ElementsCeil(seq_len): each switch
// arm corresponds to one ELEMENTS_PER_THREADS value (1 .. 128).
template <typename T>
void LaunchFusedSoftmaxMaskKernel(const T* src, const T* mask, T* dst,
                                  const int batch_size, const int head_num,
                                  const int seq_len, cudaStream_t stream) {
  PADDLE_ENFORCE_EQ(
      seq_len > 0 && seq_len <= 4096, true,
      platform::errors::InvalidArgument("seq_len must be between (0, 4096] "
                                        "received the seq_len is %d",
                                        seq_len));

  constexpr int warp_size = 32;
  constexpr int block_size = 128;
  constexpr int warps_per_block = block_size / warp_size;

  // One warp per row, warps_per_block rows per block. head_num is placed on
  // the outermost (z) grid axis so that blockIdx.y alone selects the batch
  // entry when the kernel indexes the mask.
  dim3 block(warp_size, warps_per_block);
  dim3 grid(DIV_UP(seq_len, warps_per_block), batch_size, head_num);

  // clang-format off
  const int elements_per_thread = ElementsCeil(seq_len);
  switch (elements_per_thread) {
    case 1: {  // seq_len <= 32: scalar access only
      SOFTMAX_MASK_KERNEL(1, 1);
      break;
    }
    CASE_SOFTMAX_MASK_KERNEL(2);    // <=64
    CASE_SOFTMAX_MASK_KERNEL(4);    // <=128
    CASE_SOFTMAX_MASK_KERNEL(8);    // <=256
    CASE_SOFTMAX_MASK_KERNEL(16);   // <=512
    CASE_SOFTMAX_MASK_KERNEL(32);   // <=1024
    CASE_SOFTMAX_MASK_KERNEL(64);   // <=2048
    CASE_SOFTMAX_MASK_KERNEL(128);  // <=4096
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "seq_len must be between (0, 4096], received the seq_len is %d",
          seq_len));
  }
  // clang-format on
}
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
浏览文件 @
f9d5ae4e
...
...
@@ -109,6 +109,7 @@ class TestFusedMultiTransformerOp(OpTest):
self
.
x_type
=
np
.
float32
self
.
attn_mask_type
=
np
.
float64
#self.attn_mask_type = np.bool
self
.
pre_layer_norm
=
True
self
.
has_attn_mask
=
True
...
...
@@ -168,6 +169,11 @@ class TestFusedMultiTransformerOp(OpTest):
self
.
attn_mask
=
(
self
.
attn_mask
-
1.0
)
*
1e4
else
:
self
.
attn_mask
=
(
np
.
tril
(
self
.
attn_mask
)
-
1.0
)
*
1e4
elif
self
.
attn_mask_type
==
np
.
bool
:
if
self
.
has_cache_kv
and
not
self
.
gen_cache_kv
:
self
.
attn_mask
[:,
:,
:,
-
2
]
=
0
else
:
self
.
attn_mask
=
np
.
tril
(
self
.
attn_mask
)
else
:
raise
ValueError
(
"'attn_mask_type' should be 'int64' or 'float64'."
)
...
...
@@ -394,7 +400,7 @@ class TestFusedMultiTransformerOp(OpTest):
epsilon
=
1e-05
ln2_epsilon
=
1e-05
if
attn_mask
is
not
None
:
if
attn_mask
is
not
None
and
self
.
attn_mask_type
!=
np
.
bool
:
attn_mask
=
_convert_attention_mask
(
attn_mask
,
x
.
dtype
)
qkv_weights
,
qkv_biases
=
[],
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录