PaddlePaddle / Paddle

Commit 03f9e598 (unverified)
Authored Jun 20, 2022 by Zhang Zheng; committed via GitHub on Jun 20, 2022
Support more dimensions in MMHA (#43612)
* support more dimensions
* fix
Parent: 2ddbc647
Showing 1 changed file with 112 additions and 68 deletions (+112, −68)
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {};
 template <> struct Qk_vec_<float, 32> { using Type = float; };
 template <> struct Qk_vec_<float, 64> { using Type = float2; };
 template <> struct Qk_vec_<float, 128> { using Type = float4; };
+template <> struct Qk_vec_<float, 256> { using Type = float4; };
 template <> struct Qk_vec_<float16, 32> { using Type = uint32_t; };
 template <> struct Qk_vec_<float16, 64> { using Type = uint32_t; };
 template <> struct Qk_vec_<float16, 128> { using Type = uint2; };
+template <> struct Qk_vec_<float16, 256> { using Type = uint4; };

 template <typename T, int THREADS_PER_KEY> struct K_vec_ {};
 template <> struct K_vec_<float, 4> { using Type = float; };
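Note: the Qk_vec_ trait maps an element type and a padded head size to the widest vector type that lets one warp cover a whole Q/K row with 16-byte-friendly loads; the two new specializations extend it to 256-wide (padded) heads. A minimal, self-contained sketch of how such a trait is typically consumed — illustrative code, not taken from the patch (QkVec and load_row are made-up names):

// Illustrative trait: (element type, padded head size) -> load vector type.
template <typename T, int Dh_MAX> struct QkVec {};
template <> struct QkVec<float, 128> { using Type = float4; };  // 4 floats per lane
template <> struct QkVec<float, 256> { using Type = float4; };  // 64 lanes cover 256 elements

// One warp copies a single head row with vectorized accesses.
template <typename T, int Dh_MAX>
__device__ void load_row(const T *src, T *dst, int lane) {
  using Vec = typename QkVec<T, Dh_MAX>::Type;
  constexpr int VEC_SIZE = sizeof(Vec) / sizeof(T);   // elements per access
  constexpr int VECS_PER_ROW = Dh_MAX / VEC_SIZE;     // lanes that participate
  if (lane < VECS_PER_ROW) {
    *reinterpret_cast<Vec *>(&dst[lane * VEC_SIZE]) =
        *reinterpret_cast<const Vec *>(&src[lane * VEC_SIZE]);
  }
}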
@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) { // NOLINT
 template <typename T, int Dh, int Dh_MAX, int THREADS_PER_KEY,
           int THREADS_PER_VALUE, int THREADS_PER_BLOCK>
 __global__ void masked_multihead_attention_kernel(
-    Masked_multihead_attention_params<T> params, int pad_active_groups) {
+    Masked_multihead_attention_params<T> params) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
-  static_assert(Dh % THREADS_PER_KEY == 0, "");
-  static_assert(Dh % THREADS_PER_VALUE == 0, "");
+  static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
+  static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
   constexpr int WARP_SIZE = 32;
   constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
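Note: Dh stays the real head_dim while Dh_MAX is the padded size the kernel is instantiated for, so the divisibility static_asserts and the per-thread layout constants are now computed from a value that always divides evenly; pad_active_groups also disappears from the kernel signature because the output reduction is derived from V_PER_ITER later in the diff. A hypothetical helper, not part of the patch, that captures the padding rule the dispatch switch at the bottom of the diff encodes by hand (10 and 26 pad to 32, 192 pads to 256):

// Hypothetical: round a real head size up to the padded Dh_MAX it is launched with.
constexpr int round_up_dh_max(int dh) {
  int padded = 32;                  // smallest instantiated Dh_MAX
  while (padded < dh) padded *= 2;  // grow by powers of two
  return padded;
}
static_assert(round_up_dh_max(26) == 32, "small heads pad up to 32");
static_assert(round_up_dh_max(192) == 256, "192 pads up to 256");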
@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel(
   T *out_smem = reinterpret_cast<T *>(smem_);

   __shared__ float red_smem[WARPS_PER_BLOCK * 2];
-  __shared__ T q_smem[Dh];
+  using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+  __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];

   const int bi = blockIdx.y;
   const int hi = blockIdx.x;
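Note: q_smem is now sized by Dh_MAX and aligned to the Qk_vec type, so q can be staged into shared memory with vectorized stores even when the real Dh is smaller than the padded size. A small sketch of that pattern under the same assumptions (stage_q is an illustrative name, not a function in this file):

// Sketch: vectorized staging of one q row into shared memory.
template <typename T, typename Vec, int Dh_MAX>
__device__ void stage_q(const Vec &q, int lane) {
  constexpr int VEC_SIZE = sizeof(Vec) / sizeof(T);
  // The alignment attribute makes the reinterpret_cast'ed vector store legal.
  __shared__ __align__(sizeof(Vec)) T q_smem[Dh_MAX];
  *reinterpret_cast<Vec *>(&q_smem[lane * VEC_SIZE]) = q;
}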
@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel(
   // qkv [B, S=1, 3, num_head, head_dim]
   int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;

-  using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
   constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
-  static_assert(Dh % QK_VEC_SIZE == 0 && Dh / QK_VEC_SIZE <= WARP_SIZE, "");
-  constexpr int QK_VECS_PER_WARP = Dh / QK_VEC_SIZE;
+  static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
+  // Use block reduction if needed
+  // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
+  constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;

   // cache_k, [B, num_head, head_dim / x, max_seq_len, x]
   // x == 4/8 for FP32/FP16, 128bit, 16Byte
@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel(
     int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
     int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;

-    Qk_vec q = *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset]);
-    Qk_vec k = *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset]);
-    Qk_vec q_bias = *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset]);
-    Qk_vec k_bias = *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset]);
+    Qk_vec q;
+    zero(q);
+    q = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+            ? *reinterpret_cast<const Qk_vec *>(&q_base[qk_offset])
+            : q;
+    Qk_vec k;
+    zero(k);
+    k = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+            ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
+            : k;
+
+    Qk_vec q_bias;
+    zero(q_bias);
+    q_bias = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+                 ? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
+                 : q_bias;
+    Qk_vec k_bias;
+    zero(k_bias);
+    k_bias = (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
+                 ? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
+                 : k_bias;

     q = add(q, q_bias);
     // TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
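Note: with a padded Dh_MAX, a lane's vector may start beyond the real head size, so the unconditional q/k/bias loads become zero-initialize-then-conditionally-load; out-of-range lanes keep an all-zero vector and add nothing to the dot product. A self-contained sketch of the guard (guarded_load is an illustrative name; value-initialization stands in for the file's zero() helper):

// Sketch: load a vector only if it still lies inside the real head size Dh.
template <typename Vec, typename T, int Dh, int Dh_MAX, int VEC_SIZE>
__device__ Vec guarded_load(const T *base, int offset, int tid) {
  Vec v = {};  // lanes past the real head size keep an all-zero vector
  if (Dh == Dh_MAX || tid * VEC_SIZE < Dh) {
    v = *reinterpret_cast<const Vec *>(&base[offset]);
  }
  return v;
}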
@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel(
     int offset = bhi * params.max_seq_length * Dh +
                  co * params.max_seq_length * QK_ELTS_IN_16B +
                  params.timestep * QK_ELTS_IN_16B + ci;
-    *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
+    if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
+      *reinterpret_cast<Qk_vec *>(&params.cache_kv[offset]) = k;
+    }

     qk = dot<Qk_vec, Qk_vec>(q, k);
-  }
-  if (tid < WARP_SIZE) {
-    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-      qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
-    }
-    if (tid == 0) {
-      // NOTE(wangxi): mask must be 0.0
-      // T mask = params.attn_mask[
-      //     bi * (params.timestep + 1) + params.timestep];
-      // qk += static_cast<float>(mask);
-      qk *= params.inv_sqrt_dh;
-      qk_max = qk;
-      qk_smem[params.timestep] = qk;
-    }
+
+    if (QK_VECS_PER_WARP <= WARP_SIZE) {
+#pragma unroll
+      for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
+        qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+      }
+    }
   }
+  if (QK_VECS_PER_WARP > WARP_SIZE) {
+    constexpr int WARPS_PER_RED =
+        (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
+    qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
+  }
+  if (tid == 0) {
+    // NOTE(wangxi): mask must be 0.0
+    // T mask = params.attn_mask[
+    //     bi * (params.timestep + 1) + params.timestep];
+    // qk += static_cast<float>(mask);
+    qk *= params.inv_sqrt_dh;
+    qk_max = qk;
+    qk_smem[params.timestep] = qk;
+  }
   __syncthreads();

 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
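Note: QK_VECS_PER_WARP is now Dh_MAX / QK_VEC_SIZE, so it can be smaller than a warp (FP16 with Dh_MAX = 32 leaves only 16 lanes holding partial q·k sums) or larger than one (FP32 with Dh_MAX = 256 needs 64 lanes); the reduction therefore uses a sub-warp shuffle when it fits and falls back to the shared-memory block_sum helper otherwise. A self-contained sketch of the sub-warp butterfly reduction — sub_warp_sum and its mask computation are illustrative, not the file's shfl_mask:

// Butterfly-reduce a float across the first `lanes` lanes of a warp.
// `lanes` must be a power of two (1..32) and exactly those lanes must call
// this, mirroring the kernel's tid < QK_VECS_PER_WARP guard.
template <int lanes>
__device__ float sub_warp_sum(float x) {
  constexpr unsigned member_mask =
      (lanes == 32) ? 0xffffffffu : ((1u << lanes) - 1u);
#pragma unroll
  for (int step = lanes / 2; step >= 1; step /= 2) {
    x += __shfl_xor_sync(member_mask, x, step);
  }
  return x;  // every participating lane now holds the group sum
}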
@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel(
   using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
   constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
-  static_assert(Dh % K_VEC_SIZE == 0, "");
-  constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY;
+  static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
+  constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
   constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;

   int ko = tid / THREADS_PER_KEY;
   int ki = (tid % THREADS_PER_KEY) * K_VEC_SIZE;

+  static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD, "");
+
   K_vec q[K_VECS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < K_VECS_PER_THREAD; ++i) {
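Note: in the K·Q stage, THREADS_PER_KEY threads cooperate on one key row and each owns K_VECS_PER_THREAD vectors of K_VEC_SIZE elements; the new static_assert pins down that this tiling covers exactly the padded head. Worked numbers for the one specialization visible in this diff (K_vec_<float, 4> is float), assuming Dh_MAX = 32 purely for illustration:

// Worked example: T = float, THREADS_PER_KEY = 4, K_vec = float, Dh_MAX = 32.
constexpr int kDhMax = 32;
constexpr int kThreadsPerKey = 4;
constexpr int kKVecSize = sizeof(float) / sizeof(float);      // 1 element per K_vec
constexpr int kKEltsPerThread = kDhMax / kThreadsPerKey;      // 8 elements per thread
constexpr int kKVecsPerThread = kKEltsPerThread / kKVecSize;  // 8 vectors per thread
static_assert(kDhMax == kThreadsPerKey * kKVecSize * kKVecsPerThread,
              "the per-thread K tiling covers the padded head exactly");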
@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel(
   for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
     K_vec k[K_VECS_PER_THREAD];
+    K_vec k_vec_zero;
+    zero(k_vec_zero);
 #pragma unroll
     for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
       int jj = ii * params.max_seq_length + ti;
       if (ti < params.timestep) {
-        k[ii] = *reinterpret_cast<const K_vec *>(&k_cache[jj * QK_ELTS_IN_16B]);
+        k[ii] =
+            (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.max_seq_length)
+                ? *reinterpret_cast<const K_vec *>(
+                      &k_cache[jj * QK_ELTS_IN_16B])
+                : k_vec_zero;
       }
     }
@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel(
   }
   __syncthreads();

-  constexpr int V_VEC_SIZE = Dh / THREADS_PER_VALUE;
+  constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
   using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

   int vo = tid / THREADS_PER_VALUE;
@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel(
   zero(out);

   constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
-  if (vo < V_PER_ITER) {
+  if (Dh == Dh_MAX || vi < Dh) {
     for (int ti = vo; ti < params.timestep; ti += V_PER_ITER) {
       V_vec v = *reinterpret_cast<const V_vec *>(&v_cache[ti * Dh]);
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel(
   __syncthreads();
 #endif

-  if (vo == (params.timestep % V_PER_ITER)) {
+  V_vec v_bias;
+  zero(v_bias);
+  if (vo == (params.timestep % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
     V_vec v = *reinterpret_cast<const V_vec *>(
         &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
-    V_vec v_bias = *reinterpret_cast<const V_vec *>(
+    v_bias = *reinterpret_cast<const V_vec *>(
         &params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
     v = add(v, v_bias);

     *reinterpret_cast<V_vec *>(&v_cache[params.timestep * Dh]) = v;
@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel(
   __syncthreads();

-  if (vo < pad_active_groups / 2) {
-    zero(*reinterpret_cast<V_vec *>(&out_smem[vo * Dh + vi]));
-  }
-#pragma unroll
-  for (int active_groups = pad_active_groups; active_groups >= 2;
-       active_groups /= 2) {
-    int midpoint = active_groups / 2;
-
-    if (vo >= midpoint && vo < active_groups) {
-#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-      convert_from_float(
-          *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]), out);
-#else
-      *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
-#endif
-    }
-    __syncthreads();
-
-    if (vo < midpoint) {
-      out = add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
-    }
-    __syncthreads();
-  }
-
-  if (vo == 0) {
+  if (Dh == Dh_MAX || vi < Dh) {
+#pragma unroll
+    for (int active_groups = V_PER_ITER; active_groups >= 2;
+         active_groups /= 2) {
+      int midpoint = active_groups / 2;
+
+      if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+        convert_from_float(
+            *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]),
+            out);
+#else
+        *reinterpret_cast<V_vec *>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
+#endif
+      }
+      __syncthreads();
+
+      if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
+        out = add(*reinterpret_cast<const V_vec *>(&out_smem[vo * Dh + vi]), out);
+      }
+      __syncthreads();
+    }
+  }
+
+  if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
 #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
     convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh + vi]),
                        out);
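Note: the final value reduction halves the number of active groups each round: the upper half of the groups publishes its partial output vector to shared memory, the lower half accumulates it, and the loop now runs over V_PER_ITER directly instead of a separately padded group count, with vi < Dh guards for padded heads. A scalar, self-contained sketch of the same halving pattern (illustrative names; one float per group instead of a V_vec):

// `partial` is this group's partial sum, `g` its group id, `smem` holds
// V_PER_ITER floats; every thread of the block must call this (it syncs).
template <int V_PER_ITER>
__device__ float tree_reduce_groups(float partial, int g, float *smem) {
#pragma unroll
  for (int active = V_PER_ITER; active >= 2; active /= 2) {
    int midpoint = active / 2;
    if (g >= midpoint && g < active) smem[g - midpoint] = partial;  // upper half publishes
    __syncthreads();
    if (g < midpoint) partial += smem[g];  // lower half accumulates
    __syncthreads();
  }
  return partial;  // group 0 ends up with the full sum
}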
@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel(
 template <typename T>
 inline size_t smem_size_in_bytes(
     const Masked_multihead_attention_params<T> &params, int dim_head,
-    int threads_per_value, int threads_per_block, int pad_active_groups) {
+    int threads_per_value, int threads_per_block) {
   size_t qk_sz = div_up(params.timestep + 1, 4) * 16;
   size_t logits_sz = 0;
@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes(
 #endif
   size_t softmax_sz = qk_sz + logits_sz;

-  int rows_per_red = pad_active_groups;
+  int rows_per_red = threads_per_block / threads_per_value;
   size_t red_sz = rows_per_red * dim_head * sizeof(T) / 2;

   return max(softmax_sz, red_sz);
 }

-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,        \
-                           THDS_PER_BLOCK, stream)                             \
-  int pad_active_groups =                                                      \
-      1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
-  size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE,           \
-                                         THDS_PER_BLOCK, pad_active_groups);   \
-  dim3 grid(params.num_head, params.batch_size);                               \
-  masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY,               \
-                                    THDS_PER_VALUE, THDS_PER_BLOCK>            \
-      <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,  \
+                           THDS_PER_BLOCK, stream)                       \
+  size_t smem_sz =                                                       \
+      smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
+  dim3 grid(params.num_head, params.batch_size);                         \
+  masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY,         \
+                                    THDS_PER_VALUE, THDS_PER_BLOCK>      \
+      <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

 template <typename T, int Dh, int Dh_MAX>
 void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
                         const cudaStream_t &stream) {
-  constexpr int THREADS_PER_VALUE = Dh * sizeof(T) / 16;
+  constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
   if (params.timestep < 32) {
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
   } else if (params.timestep < 2048) {
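Note: shared memory is the larger of the softmax scratch (qk_sz, one padded float slot per past timestep) and the output-reduction scratch, which now needs threads_per_block / threads_per_value half-rows of dim_head elements; pad_active_groups is gone from both the size computation and the launch macro. A host-side sketch of the same arithmetic for one hypothetical FP16 configuration (dim_head = 192 padded to Dh_MAX = 256, 256 threads per block; the numbers are examples, not taken from a real launch):

#include <algorithm>
#include <cstddef>

// THREADS_PER_VALUE = Dh_MAX * sizeof(half) / 16 = 256 * 2 / 16 = 32 here.
size_t example_smem_bytes(int timestep) {
  const int dim_head = 192, threads_per_block = 256, threads_per_value = 32;
  size_t qk_sz = ((timestep + 1 + 3) / 4) * 16;              // div_up(timestep + 1, 4) * 16
  size_t softmax_sz = qk_sz;                                 // logits_sz == 0 in this sketch
  int rows_per_red = threads_per_block / threads_per_value;  // 8 reduction rows
  size_t red_sz = rows_per_red * dim_head * sizeof(short) / 2;  // sizeof(short) stands in for half
  return std::max(softmax_sz, red_sz);
}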
@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
   params.inv_sqrt_dh = inv_sqrt_dh;

   switch (dim_head) {
+    case 10:
+      fmha_launch_kernel<T, 10, 32>(params, dev_ctx.stream());
+      break;
+    case 26:
+      fmha_launch_kernel<T, 26, 32>(params, dev_ctx.stream());
+      break;
     case 32:
       fmha_launch_kernel<T, 32, 32>(params, dev_ctx.stream());
       break;
@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
     case 128:
       fmha_launch_kernel<T, 128, 128>(params, dev_ctx.stream());
       break;
+    case 192:
+      fmha_launch_kernel<T, 192, 256>(params, dev_ctx.stream());
+      break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "dim_head = %d is unsupport, only support "
-          "dim_head = 32, 64, 96 or 128 for now.",
-          dim_head));
+          "Dim_head = %d is unsupport!", dim_head));
   }
 }