Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
03f9e598
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
03f9e598
编写于
6月 20, 2022
作者:
Z
Zhang Zheng
提交者:
GitHub
6月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support more dimensions in MMHA (#43612)
* support more dimensions * fix
上级
2ddbc647
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
112 addition
and
68 deletion
+112
-68
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+112
-68
未找到文件。
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
03f9e598
...
@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {};
...
@@ -114,9 +114,11 @@ template <typename T, int Dh> struct Qk_vec_ {};
template
<
>
struct
Qk_vec_
<
float
,
32
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_
<
float
,
32
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_
<
float
,
64
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_
<
float
,
64
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_
<
float
,
128
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float
,
128
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float
,
256
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float16
,
32
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
32
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
64
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
64
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
float16
,
128
>
{
using
Type
=
uint2
;
};
template
<
>
struct
Qk_vec_
<
float16
,
128
>
{
using
Type
=
uint2
;
};
template
<
>
struct
Qk_vec_
<
float16
,
256
>
{
using
Type
=
uint4
;
};
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
K_vec_
{};
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
K_vec_
{};
template
<
>
struct
K_vec_
<
float
,
4
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_
<
float
,
4
>
{
using
Type
=
float
;
};
...
@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) { // NOLINT
...
@@ -532,11 +534,11 @@ inline __device__ void zero(T &dst) { // NOLINT
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
,
int
THREADS_PER_KEY
,
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
,
int
THREADS_PER_KEY
,
int
THREADS_PER_VALUE
,
int
THREADS_PER_BLOCK
>
int
THREADS_PER_VALUE
,
int
THREADS_PER_BLOCK
>
__global__
void
masked_multihead_attention_kernel
(
__global__
void
masked_multihead_attention_kernel
(
Masked_multihead_attention_params
<
T
>
params
,
int
pad_active_groups
)
{
Masked_multihead_attention_params
<
T
>
params
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert
(
Dh
%
THREADS_PER_KEY
==
0
,
""
);
static_assert
(
Dh
_MAX
%
THREADS_PER_KEY
==
0
,
""
);
static_assert
(
Dh
%
THREADS_PER_VALUE
==
0
,
""
);
static_assert
(
Dh
_MAX
%
THREADS_PER_VALUE
==
0
,
""
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
...
@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -552,7 +554,8 @@ __global__ void masked_multihead_attention_kernel(
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
__shared__
T
q_smem
[
Dh
];
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
q_smem
[
Dh_MAX
];
const
int
bi
=
blockIdx
.
y
;
const
int
bi
=
blockIdx
.
y
;
const
int
hi
=
blockIdx
.
x
;
const
int
hi
=
blockIdx
.
x
;
...
@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -565,10 +568,11 @@ __global__ void masked_multihead_attention_kernel(
// qkv [B, S=1, 3, num_head, head_dim]
// qkv [B, S=1, 3, num_head, head_dim]
int
qkv_base_offset
=
bi
*
3
*
params
.
num_head
*
Dh
+
hi
*
Dh
;
int
qkv_base_offset
=
bi
*
3
*
params
.
num_head
*
Dh
+
hi
*
Dh
;
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
static_assert
(
Dh
%
QK_VEC_SIZE
==
0
&&
Dh
/
QK_VEC_SIZE
<=
WARP_SIZE
,
""
);
static_assert
(
Dh_MAX
%
QK_VEC_SIZE
==
0
,
""
);
constexpr
int
QK_VECS_PER_WARP
=
Dh
/
QK_VEC_SIZE
;
// Use block reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
constexpr
int
QK_VECS_PER_WARP
=
Dh_MAX
/
QK_VEC_SIZE
;
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// cache_k, [B, num_head, head_dim / x, max_seq_len, x]
// x == 4/8 for FP32/FP16, 128bit, 16Byte
// x == 4/8 for FP32/FP16, 128bit, 16Byte
...
@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -584,13 +588,29 @@ __global__ void masked_multihead_attention_kernel(
int
qk_offset
=
qkv_base_offset
+
tid
*
QK_VEC_SIZE
;
int
qk_offset
=
qkv_base_offset
+
tid
*
QK_VEC_SIZE
;
int
qk_bias_offset
=
hi
*
Dh
+
tid
*
QK_VEC_SIZE
;
int
qk_bias_offset
=
hi
*
Dh
+
tid
*
QK_VEC_SIZE
;
Qk_vec
q
=
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_base
[
qk_offset
]);
Qk_vec
q
;
Qk_vec
k
=
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_base
[
qk_offset
]);
zero
(
q
);
q
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
Qk_vec
q_bias
=
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_base
[
qk_offset
])
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_bias_base
[
qk_bias_offset
]);
:
q
;
Qk_vec
k_bias
=
Qk_vec
k
;
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_bias_base
[
qk_bias_offset
]);
zero
(
k
);
k
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_base
[
qk_offset
])
:
k
;
Qk_vec
q_bias
;
zero
(
q_bias
);
q_bias
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
q_bias_base
[
qk_bias_offset
])
:
q_bias
;
Qk_vec
k_bias
;
zero
(
k_bias
);
k_bias
=
(
Dh
==
Dh_MAX
||
tid
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
k_bias_base
[
qk_bias_offset
])
:
k_bias
;
q
=
add
(
q
,
q_bias
);
q
=
add
(
q
,
q_bias
);
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
// TODO(wangxi): See this https://github.com/microsoft/unilm/issues/510
...
@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -604,24 +624,33 @@ __global__ void masked_multihead_attention_kernel(
int
offset
=
bhi
*
params
.
max_seq_length
*
Dh
+
int
offset
=
bhi
*
params
.
max_seq_length
*
Dh
+
co
*
params
.
max_seq_length
*
QK_ELTS_IN_16B
+
co
*
params
.
max_seq_length
*
QK_ELTS_IN_16B
+
params
.
timestep
*
QK_ELTS_IN_16B
+
ci
;
params
.
timestep
*
QK_ELTS_IN_16B
+
ci
;
*
reinterpret_cast
<
Qk_vec
*>
(
&
params
.
cache_kv
[
offset
])
=
k
;
if
(
Dh
==
Dh_MAX
||
co
<
Dh
/
QK_ELTS_IN_16B
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
&
params
.
cache_kv
[
offset
])
=
k
;
}
qk
=
dot
<
Qk_vec
,
Qk_vec
>
(
q
,
k
);
qk
=
dot
<
Qk_vec
,
Qk_vec
>
(
q
,
k
);
}
if
(
tid
<
WARP_SIZE
)
{
if
(
QK_VECS_PER_WARP
<=
WARP_SIZE
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
#pragma unroll
qk
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk
,
mask
);
for
(
int
mask
=
QK_VECS_PER_WARP
/
2
;
mask
>=
1
;
mask
/=
2
)
{
}
qk
+=
__shfl_xor_sync
(
shfl_mask
(
QK_VECS_PER_WARP
),
qk
,
mask
);
if
(
tid
==
0
)
{
}
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk
*=
params
.
inv_sqrt_dh
;
qk_max
=
qk
;
qk_smem
[
params
.
timestep
]
=
qk
;
}
}
}
}
if
(
QK_VECS_PER_WARP
>
WARP_SIZE
)
{
constexpr
int
WARPS_PER_RED
=
(
QK_VECS_PER_WARP
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
qk
=
block_sum
<
WARPS_PER_RED
>
(
&
red_smem
[
WARPS_PER_RED
],
qk
);
}
if
(
tid
==
0
)
{
// NOTE(wangxi): mask must be 0.0
// T mask = params.attn_mask[
// bi * (params.timestep + 1) + params.timestep];
// qk += static_cast<float>(mask);
qk
*=
params
.
inv_sqrt_dh
;
qk_max
=
qk
;
qk_smem
[
params
.
timestep
]
=
qk
;
}
__syncthreads
();
__syncthreads
();
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
...
@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -635,13 +664,15 @@ __global__ void masked_multihead_attention_kernel(
using
K_vec
=
typename
K_vec_
<
T
,
THREADS_PER_KEY
>::
Type
;
using
K_vec
=
typename
K_vec_
<
T
,
THREADS_PER_KEY
>::
Type
;
constexpr
int
K_VEC_SIZE
=
sizeof
(
K_vec
)
/
sizeof
(
T
);
constexpr
int
K_VEC_SIZE
=
sizeof
(
K_vec
)
/
sizeof
(
T
);
static_assert
(
Dh
%
K_VEC_SIZE
==
0
,
""
);
static_assert
(
Dh
_MAX
%
K_VEC_SIZE
==
0
,
""
);
constexpr
int
K_ELTS_PER_THREAD
=
Dh
/
THREADS_PER_KEY
;
constexpr
int
K_ELTS_PER_THREAD
=
Dh
_MAX
/
THREADS_PER_KEY
;
constexpr
int
K_VECS_PER_THREAD
=
K_ELTS_PER_THREAD
/
K_VEC_SIZE
;
constexpr
int
K_VECS_PER_THREAD
=
K_ELTS_PER_THREAD
/
K_VEC_SIZE
;
int
ko
=
tid
/
THREADS_PER_KEY
;
int
ko
=
tid
/
THREADS_PER_KEY
;
int
ki
=
(
tid
%
THREADS_PER_KEY
)
*
K_VEC_SIZE
;
int
ki
=
(
tid
%
THREADS_PER_KEY
)
*
K_VEC_SIZE
;
static_assert
(
Dh_MAX
==
THREADS_PER_KEY
*
K_VEC_SIZE
*
K_VECS_PER_THREAD
,
""
);
K_vec
q
[
K_VECS_PER_THREAD
];
K_vec
q
[
K_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
K_VECS_PER_THREAD
;
++
i
)
{
for
(
int
i
=
0
;
i
<
K_VECS_PER_THREAD
;
++
i
)
{
...
@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -657,11 +688,17 @@ __global__ void masked_multihead_attention_kernel(
for
(
int
ti
=
ko
;
ti
<
ti_end
;
ti
+=
K_PER_ITER
)
{
for
(
int
ti
=
ko
;
ti
<
ti_end
;
ti
+=
K_PER_ITER
)
{
K_vec
k
[
K_VECS_PER_THREAD
];
K_vec
k
[
K_VECS_PER_THREAD
];
K_vec
k_vec_zero
;
zero
(
k_vec_zero
);
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
int
jj
=
ii
*
params
.
max_seq_length
+
ti
;
int
jj
=
ii
*
params
.
max_seq_length
+
ti
;
if
(
ti
<
params
.
timestep
)
{
if
(
ti
<
params
.
timestep
)
{
k
[
ii
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
k_cache
[
jj
*
QK_ELTS_IN_16B
]);
k
[
ii
]
=
(
Dh
==
Dh_MAX
||
jj
*
QK_ELTS_IN_16B
<
Dh
*
params
.
max_seq_length
)
?
*
reinterpret_cast
<
const
K_vec
*>
(
&
k_cache
[
jj
*
QK_ELTS_IN_16B
])
:
k_vec_zero
;
}
}
}
}
...
@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -727,7 +764,7 @@ __global__ void masked_multihead_attention_kernel(
}
}
__syncthreads
();
__syncthreads
();
constexpr
int
V_VEC_SIZE
=
Dh
/
THREADS_PER_VALUE
;
constexpr
int
V_VEC_SIZE
=
Dh
_MAX
/
THREADS_PER_VALUE
;
using
V_vec
=
typename
V_vec_
<
T
,
V_VEC_SIZE
>::
Type
;
using
V_vec
=
typename
V_vec_
<
T
,
V_VEC_SIZE
>::
Type
;
int
vo
=
tid
/
THREADS_PER_VALUE
;
int
vo
=
tid
/
THREADS_PER_VALUE
;
...
@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -747,7 +784,7 @@ __global__ void masked_multihead_attention_kernel(
zero
(
out
);
zero
(
out
);
constexpr
int
V_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_VALUE
;
constexpr
int
V_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_VALUE
;
if
(
vo
<
V_PER_ITER
)
{
if
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
{
for
(
int
ti
=
vo
;
ti
<
params
.
timestep
;
ti
+=
V_PER_ITER
)
{
for
(
int
ti
=
vo
;
ti
<
params
.
timestep
;
ti
+=
V_PER_ITER
)
{
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
v_cache
[
ti
*
Dh
]);
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
v_cache
[
ti
*
Dh
]);
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
...
@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -770,10 +807,12 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads
();
__syncthreads
();
#endif
#endif
if
(
vo
==
(
params
.
timestep
%
V_PER_ITER
))
{
V_vec
v_bias
;
zero
(
v_bias
);
if
(
vo
==
(
params
.
timestep
%
V_PER_ITER
)
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
V_vec
v
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
params
.
qkv
[
2
*
params
.
num_head
*
Dh
+
qkv_base_offset
+
vi
]);
&
params
.
qkv
[
2
*
params
.
num_head
*
Dh
+
qkv_base_offset
+
vi
]);
V_vec
v_bias
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_bias
=
*
reinterpret_cast
<
const
V_vec
*>
(
&
params
.
qkv_bias
[
2
*
params
.
num_head
*
Dh
+
hi
*
Dh
+
vi
]);
&
params
.
qkv_bias
[
2
*
params
.
num_head
*
Dh
+
hi
*
Dh
+
vi
]);
v
=
add
(
v
,
v_bias
);
v
=
add
(
v
,
v_bias
);
*
reinterpret_cast
<
V_vec
*>
(
&
v_cache
[
params
.
timestep
*
Dh
])
=
v
;
*
reinterpret_cast
<
V_vec
*>
(
&
v_cache
[
params
.
timestep
*
Dh
])
=
v
;
...
@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -787,31 +826,31 @@ __global__ void masked_multihead_attention_kernel(
__syncthreads
();
__syncthreads
();
if
(
vo
<
pad_active_groups
/
2
)
{
if
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
{
zero
(
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[
vo
*
Dh
+
vi
]));
}
#pragma unroll
#pragma unroll
for
(
int
active_groups
=
pad_active_groups
;
active_groups
>=
2
;
for
(
int
active_groups
=
V_PER_ITER
;
active_groups
>=
2
;
active_groups
/=
2
)
{
active_groups
/=
2
)
{
int
midpoint
=
active_groups
/
2
;
int
midpoint
=
active_groups
/
2
;
if
(
vo
>=
midpoint
&&
vo
<
active_groups
)
{
if
(
vo
>=
midpoint
&&
vo
<
active_groups
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
)
{
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float
(
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
]),
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
]),
out
);
out
);
#else
#else
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
])
=
out
;
*
reinterpret_cast
<
V_vec
*>
(
&
out_smem
[(
vo
-
midpoint
)
*
Dh
+
vi
])
=
out
;
#endif
#endif
}
__syncthreads
();
if
(
vo
<
midpoint
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
))
{
out
=
add
(
*
reinterpret_cast
<
const
V_vec
*>
(
&
out_smem
[
vo
*
Dh
+
vi
]),
out
);
}
__syncthreads
();
}
}
__syncthreads
();
if
(
vo
<
midpoint
)
{
out
=
add
(
*
reinterpret_cast
<
const
V_vec
*>
(
&
out_smem
[
vo
*
Dh
+
vi
]),
out
);
}
__syncthreads
();
}
}
if
(
vo
==
0
)
{
if
(
vo
==
0
&&
(
Dh
==
Dh_MAX
||
vi
<
Dh
)
)
{
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
params
.
out
[
bhi
*
Dh
+
vi
]),
convert_from_float
(
*
reinterpret_cast
<
V_vec
*>
(
&
params
.
out
[
bhi
*
Dh
+
vi
]),
out
);
out
);
...
@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel(
...
@@ -837,7 +876,7 @@ __global__ void masked_multihead_attention_kernel(
template
<
typename
T
>
template
<
typename
T
>
inline
size_t
smem_size_in_bytes
(
inline
size_t
smem_size_in_bytes
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
int
dim_head
,
const
Masked_multihead_attention_params
<
T
>
&
params
,
int
dim_head
,
int
threads_per_value
,
int
threads_per_block
,
int
pad_active_groups
)
{
int
threads_per_value
,
int
threads_per_block
)
{
size_t
qk_sz
=
div_up
(
params
.
timestep
+
1
,
4
)
*
16
;
size_t
qk_sz
=
div_up
(
params
.
timestep
+
1
,
4
)
*
16
;
size_t
logits_sz
=
0
;
size_t
logits_sz
=
0
;
...
@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes(
...
@@ -848,27 +887,25 @@ inline size_t smem_size_in_bytes(
#endif
#endif
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
int
rows_per_red
=
pad_active_groups
;
int
rows_per_red
=
threads_per_block
/
threads_per_value
;
size_t
red_sz
=
rows_per_red
*
dim_head
*
sizeof
(
T
)
/
2
;
size_t
red_sz
=
rows_per_red
*
dim_head
*
sizeof
(
T
)
/
2
;
return
max
(
softmax_sz
,
red_sz
);
return
max
(
softmax_sz
,
red_sz
);
}
}
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, stream) \
THDS_PER_BLOCK, stream) \
int pad_active_groups = \
size_t smem_sz = \
1 << static_cast<int>(ceil(std::log2(THDS_PER_BLOCK / THDS_PER_VALUE))); \
smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \
size_t smem_sz = smem_size_in_bytes<T>(params, Dh, THDS_PER_VALUE, \
dim3 grid(params.num_head, params.batch_size); \
THDS_PER_BLOCK, pad_active_groups); \
masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \
dim3 grid(params.num_head, params.batch_size); \
THDS_PER_VALUE, THDS_PER_BLOCK> \
masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
THDS_PER_VALUE, THDS_PER_BLOCK> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params, pad_active_groups)
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
>
template
<
typename
T
,
int
Dh
,
int
Dh_MAX
>
void
fmha_launch_kernel
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
void
fmha_launch_kernel
(
const
Masked_multihead_attention_params
<
T
>
&
params
,
const
cudaStream_t
&
stream
)
{
const
cudaStream_t
&
stream
)
{
constexpr
int
THREADS_PER_VALUE
=
Dh
*
sizeof
(
T
)
/
16
;
constexpr
int
THREADS_PER_VALUE
=
Dh
_MAX
*
sizeof
(
T
)
/
16
;
if
(
params
.
timestep
<
32
)
{
if
(
params
.
timestep
<
32
)
{
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
4
,
THREADS_PER_VALUE
,
64
,
stream
);
MMHA_LAUNCH_KERNEL
(
T
,
Dh
,
Dh_MAX
,
4
,
THREADS_PER_VALUE
,
64
,
stream
);
}
else
if
(
params
.
timestep
<
2048
)
{
}
else
if
(
params
.
timestep
<
2048
)
{
...
@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
...
@@ -898,6 +935,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
params
.
inv_sqrt_dh
=
inv_sqrt_dh
;
params
.
inv_sqrt_dh
=
inv_sqrt_dh
;
switch
(
dim_head
)
{
switch
(
dim_head
)
{
case
10
:
fmha_launch_kernel
<
T
,
10
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
26
:
fmha_launch_kernel
<
T
,
26
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
case
32
:
case
32
:
fmha_launch_kernel
<
T
,
32
,
32
>
(
params
,
dev_ctx
.
stream
());
fmha_launch_kernel
<
T
,
32
,
32
>
(
params
,
dev_ctx
.
stream
());
break
;
break
;
...
@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
...
@@ -910,11 +953,12 @@ void fmha(const platform::CUDADeviceContext &dev_ctx, const Tensor &qkv_tensor,
case
128
:
case
128
:
fmha_launch_kernel
<
T
,
128
,
128
>
(
params
,
dev_ctx
.
stream
());
fmha_launch_kernel
<
T
,
128
,
128
>
(
params
,
dev_ctx
.
stream
());
break
;
break
;
case
192
:
fmha_launch_kernel
<
T
,
192
,
256
>
(
params
,
dev_ctx
.
stream
());
break
;
default:
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"dim_head = %d is unsupport, only support "
"Dim_head = %d is unsupport!"
,
dim_head
));
"dim_head = 32, 64, 96 or 128 for now."
,
dim_head
));
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录