未验证 提交 6839a7b9 编写于 作者: L lzy 提交者: GitHub

make variable_length_memory_efficient_attention support mask_broadcast_heads (#56673)

上级 f5d9981e
...@@ -77,8 +77,7 @@ template < ...@@ -77,8 +77,7 @@ template <
int KeysPerBlock_, int KeysPerBlock_,
bool SingleValueIteration_, bool SingleValueIteration_,
GroupScheduleMode GroupScheduleMode_, GroupScheduleMode GroupScheduleMode_,
bool AddMask, bool AddMask>
bool MaskBroadcastRow>
struct DefaultFMHAGrouped { struct DefaultFMHAGrouped {
using scalar_t = scalar_t_; using scalar_t = scalar_t_;
using accum_t = float; using accum_t = float;
...@@ -92,7 +91,6 @@ struct DefaultFMHAGrouped { ...@@ -92,7 +91,6 @@ struct DefaultFMHAGrouped {
using ArchTag = ArchTag_; using ArchTag = ArchTag_;
static bool const kIsAligned = isAligned_; static bool const kIsAligned = isAligned_;
static bool const kAddMask = AddMask; static bool const kAddMask = AddMask;
static bool const kMaskBroadcastRow = MaskBroadcastRow;
static bool const kSingleValueIteration = SingleValueIteration_; static bool const kSingleValueIteration = SingleValueIteration_;
static int const kKeysPerBlock = KeysPerBlock_; static int const kKeysPerBlock = KeysPerBlock_;
static bool const kMaskIsAligned = maskIsAligned_; static bool const kMaskIsAligned = maskIsAligned_;
...@@ -288,7 +286,6 @@ struct DefaultFMHAGrouped { ...@@ -288,7 +286,6 @@ struct DefaultFMHAGrouped {
SingleValueIteration_, SingleValueIteration_,
GroupScheduleMode_, GroupScheduleMode_,
AddMask, AddMask,
MaskBroadcastRow,
maskIsAligned_>; maskIsAligned_>;
}; };
......
...@@ -72,7 +72,6 @@ template <typename MM0_, ///! Structure for computing P = Q @ K ...@@ -72,7 +72,6 @@ template <typename MM0_, ///! Structure for computing P = Q @ K
/// perform /// perform
bool kAddMask, bool kAddMask,
// This is quite faster when mask need broadcast at row axis // This is quite faster when mask need broadcast at row axis
bool kMaskBroadcastRow,
bool kMaskIsAligned> bool kMaskIsAligned>
struct FMHAGrouped { struct FMHAGrouped {
public: public:
...@@ -191,6 +190,7 @@ struct FMHAGrouped { ...@@ -191,6 +190,7 @@ struct FMHAGrouped {
// Whether causal masking is to be performed // Whether causal masking is to be performed
bool causal; bool causal;
bool mask_broadcast_head;
// Only used by device-level operator // Only used by device-level operator
GemmCoord *host_problem_sizes; GemmCoord *host_problem_sizes;
...@@ -224,6 +224,7 @@ struct FMHAGrouped { ...@@ -224,6 +224,7 @@ struct FMHAGrouped {
kElementV(0), kElementV(0),
kElementO(0), kElementO(0),
causal(false), causal(false),
mask_broadcast_head(true),
host_problem_sizes(nullptr) {} host_problem_sizes(nullptr) {}
/// Ctor /// Ctor
...@@ -250,6 +251,7 @@ struct FMHAGrouped { ...@@ -250,6 +251,7 @@ struct FMHAGrouped {
int64_t kElementV, int64_t kElementV,
int64_t kElementO, int64_t kElementO,
bool causal, bool causal,
bool mask_broadcast_head,
ElementAccumulator scale, ElementAccumulator scale,
GemmCoord *host_problem_sizes = nullptr) GemmCoord *host_problem_sizes = nullptr)
: problem_sizes0(problem_sizes0), : problem_sizes0(problem_sizes0),
...@@ -276,6 +278,7 @@ struct FMHAGrouped { ...@@ -276,6 +278,7 @@ struct FMHAGrouped {
kElementV(kElementV), kElementV(kElementV),
kElementO(kElementO), kElementO(kElementO),
causal(causal), causal(causal),
mask_broadcast_head(mask_broadcast_head),
scale(scale), scale(scale),
host_problem_sizes(host_problem_sizes) {} host_problem_sizes(host_problem_sizes) {}
...@@ -327,6 +330,7 @@ struct FMHAGrouped { ...@@ -327,6 +330,7 @@ struct FMHAGrouped {
ElementAccumulator scale; ElementAccumulator scale;
bool causal; bool causal;
bool mask_broadcast_head;
// //
// Methods // Methods
...@@ -352,6 +356,7 @@ struct FMHAGrouped { ...@@ -352,6 +356,7 @@ struct FMHAGrouped {
kElementV(0), kElementV(0),
kElementO(0), kElementO(0),
causal(false), causal(false),
mask_broadcast_head(true),
scale(0) {} scale(0) {}
explicit CUTLASS_HOST_DEVICE Params(Arguments const &args, explicit CUTLASS_HOST_DEVICE Params(Arguments const &args,
...@@ -384,6 +389,7 @@ struct FMHAGrouped { ...@@ -384,6 +389,7 @@ struct FMHAGrouped {
kElementV(args.kElementV), kElementV(args.kElementV),
kElementO(args.kElementO), kElementO(args.kElementO),
causal(args.causal), causal(args.causal),
mask_broadcast_head(args.mask_broadcast_head),
scale(args.scale) {} scale(args.scale) {}
// CUTLASS_HOST_DEVICE // CUTLASS_HOST_DEVICE
...@@ -704,6 +710,8 @@ struct FMHAGrouped { ...@@ -704,6 +710,8 @@ struct FMHAGrouped {
// apply attention mask if applicable // apply attention mask if applicable
if (kAddMask) { if (kAddMask) {
const int mask_id =
params.mask_broadcast_head ? batch_idx : problem_idx;
accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()( accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(
params.scale, accum); params.scale, accum);
// load mask tile Bij into shared memory // load mask tile Bij into shared memory
...@@ -711,7 +719,7 @@ struct FMHAGrouped { ...@@ -711,7 +719,7 @@ struct FMHAGrouped {
{cutlass::layout::RowMajor(params.ldm)}, {cutlass::layout::RowMajor(params.ldm)},
// attn_mask_pointer points to matrix of size (n_queries, n_keys) // attn_mask_pointer points to matrix of size (n_queries, n_keys)
// for the relevant batch_id and head_id // for the relevant batch_id and head_id
params.ptr_M + batch_idx * params.kElementM + params.ptr_M + mask_id * params.kElementM +
TileParams::query_start(threadblock_idx) * params.ldm + TileParams::query_start(threadblock_idx) * params.ldm +
iter_key_start, iter_key_start,
{problem_size_0_m, problem_size_0_n}, {problem_size_0_m, problem_size_0_n},
......
...@@ -200,6 +200,7 @@ void {NAME}({CPP_CLASS} default_fmha, Params &params, const phi::GPUContext& ct ...@@ -200,6 +200,7 @@ void {NAME}({CPP_CLASS} default_fmha, Params &params, const phi::GPUContext& ct
params.ElementV, params.ElementV,
params.ElementO, params.ElementO,
params.causal, params.causal,
params.mask_broadcast_head,
params.scale, params.scale,
problem_sizes1.data()); problem_sizes1.data());
...@@ -234,7 +235,6 @@ class FwdKernel: ...@@ -234,7 +235,6 @@ class FwdKernel:
k: int k: int
single_value_iter: bool single_value_iter: bool
support_mask: bool = True support_mask: bool = True
mask_broadcast: bool = False
dispatch_cond: Optional[str] = None dispatch_cond: Optional[str] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -249,7 +249,6 @@ class FwdKernel: ...@@ -249,7 +249,6 @@ class FwdKernel:
0 if self.single_value_iter else 1, 0 if self.single_value_iter else 1,
self.q, self.q,
0 if self.mask_aligned else 1, 0 if self.mask_aligned else 1,
0 if self.mask_broadcast else 1,
) )
@property @property
...@@ -264,10 +263,6 @@ class FwdKernel: ...@@ -264,10 +263,6 @@ class FwdKernel:
def _mask_support_suffix(self) -> str: def _mask_support_suffix(self) -> str:
return "sm" if self.support_mask else "usm" return "sm" if self.support_mask else "usm"
@property
def _mask_broadcast_suffix(self) -> str:
return "mb" if self.mask_broadcast else "mnb"
@property @property
def _single_value_suffix(self) -> str: def _single_value_suffix(self) -> str:
return "rf" if self.single_value_iter else "urf" return "rf" if self.single_value_iter else "urf"
...@@ -289,7 +284,6 @@ class FwdKernel: ...@@ -289,7 +284,6 @@ class FwdKernel:
"true" if self.single_value_iter else "false", "true" if self.single_value_iter else "false",
"cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
"true" if self.support_mask else "false", "true" if self.support_mask else "false",
"false",
] ]
) )
return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>" return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>"
...@@ -297,7 +291,7 @@ class FwdKernel: ...@@ -297,7 +291,7 @@ class FwdKernel:
@property @property
def impl_group(self) -> str: def impl_group(self) -> str:
# Maps to file which will contain the implementation # Maps to file which will contain the implementation
return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}" return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
@property @property
def cpp_impl(self) -> str: def cpp_impl(self) -> str:
...@@ -336,7 +330,6 @@ class FwdKernel: ...@@ -336,7 +330,6 @@ class FwdKernel:
single_value_iter=single_value_iter, single_value_iter=single_value_iter,
support_mask=support_mask, support_mask=support_mask,
mask_aligned=mask_aligned, mask_aligned=mask_aligned,
mask_broadcast=False,
) )
) )
return kernels return kernels
...@@ -490,7 +483,7 @@ struct Params {{ ...@@ -490,7 +483,7 @@ struct Params {{
int64_t ElementO; int64_t ElementO;
bool causal; bool causal;
bool mask_broadcast_row; bool mask_broadcast_head;
}}; }};
__global__ static void get_problem_sizes(const int* seq_lens, __global__ static void get_problem_sizes(const int* seq_lens,
......
...@@ -65,10 +65,11 @@ void MultiHeadAttentionVariableForwardKernel( ...@@ -65,10 +65,11 @@ void MultiHeadAttentionVariableForwardKernel(
if (mask) { if (mask) {
// [B, 1, S, D] // [B, 1, S, D]
auto mask_tensor = mask.get(); auto mask_tensor = mask.get();
int64_t mask_num_heads = mask_tensor.dims()[1];
params.ldm = mask_tensor.dims()[3]; params.ldm = mask_tensor.dims()[3];
params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3]; params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3];
params.mask_ptr = mask_tensor.data(); params.mask_ptr = mask_tensor.data();
params.mask_broadcast_row = false; params.mask_broadcast_head = mask_num_heads == 1 ? true : false;
} }
bool kernel_launched = false; bool kernel_launched = false;
...@@ -84,10 +85,6 @@ void MultiHeadAttentionVariableForwardKernel( ...@@ -84,10 +85,6 @@ void MultiHeadAttentionVariableForwardKernel(
if (!mask && KernelType::kAddMask) { if (!mask && KernelType::kAddMask) {
return; return;
} }
if (KernelType::kMaskBroadcastRow) {
// not support mask_broad_cast
return;
}
if (mask && reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 && if (mask && reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 &&
params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) { params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) {
return; return;
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册