Unverified commit 6839a7b9, authored by lzy, committed by GitHub

make variable_length_memory_efficient_attention support mask_broadcast_heads (#56673)

Parent f5d9981e
......@@ -77,8 +77,7 @@ template <
int KeysPerBlock_,
bool SingleValueIteration_,
GroupScheduleMode GroupScheduleMode_,
bool AddMask,
bool MaskBroadcastRow>
bool AddMask>
struct DefaultFMHAGrouped {
using scalar_t = scalar_t_;
using accum_t = float;
......@@ -92,7 +91,6 @@ struct DefaultFMHAGrouped {
using ArchTag = ArchTag_;
static bool const kIsAligned = isAligned_;
static bool const kAddMask = AddMask;
static bool const kMaskBroadcastRow = MaskBroadcastRow;
static bool const kSingleValueIteration = SingleValueIteration_;
static int const kKeysPerBlock = KeysPerBlock_;
static bool const kMaskIsAligned = maskIsAligned_;
......@@ -288,7 +286,6 @@ struct DefaultFMHAGrouped {
SingleValueIteration_,
GroupScheduleMode_,
AddMask,
MaskBroadcastRow,
maskIsAligned_>;
};
......
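Note: the hunks above remove the MaskBroadcastRow template parameter and the matching kMaskBroadcastRow constant from DefaultFMHAGrouped; the dispatcher further down never enabled it, and head-axis broadcast is instead carried as a runtime field on the kernel Params in the following files. Below is a minimal, self-contained sketch of that general trade-off — a compile-time bool template parameter versus a runtime flag — using purely illustrative names, not the real kernel types.

```cpp
#include <iostream>

// Illustrative only: a compile-time flag forces one instantiation per value
// plus a matching dispatch branch, which is what the removed template
// parameter would have required.
template <bool kBroadcast>
int PickMaskIdCompileTime(int batch_idx, int problem_idx) {
  return kBroadcast ? batch_idx : problem_idx;
}

// A runtime flag (as with the new mask_broadcast_head field) lets a single
// instantiation serve both mask layouts.
struct ParamsSketch {
  bool mask_broadcast_head;
};

int PickMaskIdRuntime(const ParamsSketch &p, int batch_idx, int problem_idx) {
  return p.mask_broadcast_head ? batch_idx : problem_idx;
}

int main() {
  std::cout << PickMaskIdCompileTime<true>(1, 5) << " "
            << PickMaskIdCompileTime<false>(1, 5) << " "
            << PickMaskIdRuntime({true}, 1, 5) << " "
            << PickMaskIdRuntime({false}, 1, 5) << "\n";  // prints: 1 5 1 5
  return 0;
}
```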
......@@ -72,7 +72,6 @@ template <typename MM0_, ///! Structure for computing P = Q @ K
/// perform
bool kAddMask,
// This is quite faster when mask need broadcast at row axis
bool kMaskBroadcastRow,
bool kMaskIsAligned>
struct FMHAGrouped {
public:
......@@ -191,6 +190,7 @@ struct FMHAGrouped {
// Whether causal masking is to be performed
bool causal;
bool mask_broadcast_head;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
......@@ -224,6 +224,7 @@ struct FMHAGrouped {
kElementV(0),
kElementO(0),
causal(false),
mask_broadcast_head(true),
host_problem_sizes(nullptr) {}
/// Ctor
......@@ -250,6 +251,7 @@ struct FMHAGrouped {
int64_t kElementV,
int64_t kElementO,
bool causal,
bool mask_broadcast_head,
ElementAccumulator scale,
GemmCoord *host_problem_sizes = nullptr)
: problem_sizes0(problem_sizes0),
......@@ -276,6 +278,7 @@ struct FMHAGrouped {
kElementV(kElementV),
kElementO(kElementO),
causal(causal),
mask_broadcast_head(mask_broadcast_head),
scale(scale),
host_problem_sizes(host_problem_sizes) {}
......@@ -327,6 +330,7 @@ struct FMHAGrouped {
ElementAccumulator scale;
bool causal;
bool mask_broadcast_head;
//
// Methods
......@@ -352,6 +356,7 @@ struct FMHAGrouped {
kElementV(0),
kElementO(0),
causal(false),
mask_broadcast_head(true),
scale(0) {}
explicit CUTLASS_HOST_DEVICE Params(Arguments const &args,
......@@ -384,6 +389,7 @@ struct FMHAGrouped {
kElementV(args.kElementV),
kElementO(args.kElementO),
causal(args.causal),
mask_broadcast_head(args.mask_broadcast_head),
scale(args.scale) {}
// CUTLASS_HOST_DEVICE
......@@ -704,6 +710,8 @@ struct FMHAGrouped {
// apply attention mask if applicable
if (kAddMask) {
const int mask_id =
params.mask_broadcast_head ? batch_idx : problem_idx;
accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(
params.scale, accum);
// load mask tile Bij into shared memory
......@@ -711,7 +719,7 @@ struct FMHAGrouped {
{cutlass::layout::RowMajor(params.ldm)},
// attn_mask_pointer points to matrix of size (n_queries, n_keys)
// for the relevant batch_id and head_id
params.ptr_M + batch_idx * params.kElementM +
params.ptr_M + mask_id * params.kElementM +
TileParams::query_start(threadblock_idx) * params.ldm +
iter_key_start,
{problem_size_0_m, problem_size_0_n},
......
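Note: the key change in this kernel is the mask_id selection above. When mask_broadcast_head is set, every head of a batch reads the same mask slice, addressed by batch_idx; otherwise each problem gets its own slice, addressed by problem_idx. The following is a host-side sketch of the resulting address arithmetic only, assuming problem_idx is the flattened (batch, head) problem index and that kElementM is the element count of one mask slice (seq_len_q * ldm), as in the diff; it is not the kernel itself.

```cpp
#include <cstdint>
#include <iostream>

// Sketch of the mask tile addressing: offset from ptr_M where a threadblock
// starts reading its mask tile. problem_idx = batch * num_heads + head is an
// assumption for illustration.
int64_t MaskTileOffset(bool mask_broadcast_head,
                       int batch_idx,
                       int problem_idx,
                       int64_t kElementM,      // elements per mask slice
                       int64_t ldm,            // mask leading dimension
                       int64_t query_start,
                       int64_t iter_key_start) {
  const int mask_id = mask_broadcast_head ? batch_idx : problem_idx;
  return mask_id * kElementM + query_start * ldm + iter_key_start;
}

int main() {
  const int num_heads = 8;
  const int64_t seq_q = 128, seq_k = 128;
  const int64_t kElementM = seq_q * seq_k;  // one [S_q, S_k] mask slice
  const int batch_idx = 2, head_idx = 3;
  const int problem_idx = batch_idx * num_heads + head_idx;
  // Shared mask ([B, 1, S_q, S_k]): all heads of batch 2 hit the same slice.
  std::cout << MaskTileOffset(true, batch_idx, problem_idx, kElementM, seq_k, 0, 0)
            << "\n";
  // Per-head mask ([B, H, S_q, S_k]): each (batch, head) has its own slice.
  std::cout << MaskTileOffset(false, batch_idx, problem_idx, kElementM, seq_k, 0, 0)
            << "\n";
  return 0;
}
```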
......@@ -200,6 +200,7 @@ void {NAME}({CPP_CLASS} default_fmha, Params &params, const phi::GPUContext& ct
params.ElementV,
params.ElementO,
params.causal,
params.mask_broadcast_head,
params.scale,
problem_sizes1.data());
......@@ -234,7 +235,6 @@ class FwdKernel:
k: int
single_value_iter: bool
support_mask: bool = True
mask_broadcast: bool = False
dispatch_cond: Optional[str] = None
def __post_init__(self) -> None:
......@@ -249,7 +249,6 @@ class FwdKernel:
0 if self.single_value_iter else 1,
self.q,
0 if self.mask_aligned else 1,
0 if self.mask_broadcast else 1,
)
@property
......@@ -264,10 +263,6 @@ class FwdKernel:
def _mask_support_suffix(self) -> str:
return "sm" if self.support_mask else "usm"
@property
def _mask_broadcast_suffix(self) -> str:
return "mb" if self.mask_broadcast else "mnb"
@property
def _single_value_suffix(self) -> str:
return "rf" if self.single_value_iter else "urf"
......@@ -289,7 +284,6 @@ class FwdKernel:
"true" if self.single_value_iter else "false",
"cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
"true" if self.support_mask else "false",
"false",
]
)
return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>"
......@@ -297,7 +291,7 @@ class FwdKernel:
@property
def impl_group(self) -> str:
# Maps to file which will contain the implementation
return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
@property
def cpp_impl(self) -> str:
......@@ -336,7 +330,6 @@ class FwdKernel:
single_value_iter=single_value_iter,
support_mask=support_mask,
mask_aligned=mask_aligned,
mask_broadcast=False,
)
)
return kernels
......@@ -490,7 +483,7 @@ struct Params {{
int64_t ElementO;
bool causal;
bool mask_broadcast_row;
bool mask_broadcast_head;
}};
__global__ static void get_problem_sizes(const int* seq_lens,
......
......@@ -65,10 +65,11 @@ void MultiHeadAttentionVariableForwardKernel(
if (mask) {
// [B, 1, S, D]
auto mask_tensor = mask.get();
int64_t mask_num_heads = mask_tensor.dims()[1];
params.ldm = mask_tensor.dims()[3];
params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3];
params.mask_ptr = mask_tensor.data();
params.mask_broadcast_row = false;
params.mask_broadcast_head = mask_num_heads == 1 ? true : false;
}
bool kernel_launched = false;
......@@ -84,10 +85,6 @@ void MultiHeadAttentionVariableForwardKernel(
if (!mask && KernelType::kAddMask) {
return;
}
if (KernelType::kMaskBroadcastRow) {
// not support mask_broad_cast
return;
}
if (mask && reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 &&
params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) {
return;
......
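Note: in the launcher above, mask_broadcast_head is derived from the mask's head dimension: a [B, 1, S_q, S_k] mask is broadcast across all heads, while a [B, num_heads, S_q, S_k] mask is indexed per head. A standalone sketch of that derivation, mirroring the lines above (the [B, H, S_q, S_k] layout is taken from the diff's comment; MaskParamsSketch stands in for the real Params struct):

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Sketch of how the launcher fills the mask-related Params fields from the
// mask tensor dims.
struct MaskParamsSketch {
  int64_t ldm = 0;                  // mask leading dimension (seq_len_k)
  int64_t ElementM = 0;             // elements per mask slice (S_q * S_k)
  bool mask_broadcast_head = true;  // true when the mask has a single head
};

MaskParamsSketch FillMaskParams(const std::array<int64_t, 4> &mask_dims) {
  // mask_dims follows [B, H, S_q, S_k]; H == 1 means one mask shared by all
  // heads of the batch.
  MaskParamsSketch p;
  p.ldm = mask_dims[3];
  p.ElementM = mask_dims[2] * mask_dims[3];
  p.mask_broadcast_head = (mask_dims[1] == 1);
  return p;
}

int main() {
  auto shared = FillMaskParams({4, 1, 128, 128});    // broadcast across heads
  auto per_head = FillMaskParams({4, 8, 128, 128});  // one slice per head
  std::cout << shared.mask_broadcast_head << " "
            << per_head.mask_broadcast_head << "\n";  // prints: 1 0
  return 0;
}
```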