Unverified commit 4036c937, authored by lzy, committed by GitHub

Add variable_length_memory_efficient_attention (#55400)

* add variable_length_memory_efficient_attention
* update variable_length_memory_efficient_attention unittest
* update variable_length_mem_eff_attn's docs and unittest
* update variable_length_mem_eff_attn's docs
* Update test_variable_length_memory_efficient_attention.py
* Update variable_length_memory_efficient_attention.cu
* fix codestyle
* fix variable_length_fmha's docs and unittest
* fix variable_length_fmha's docs
Parent b561a05e
@@ -2682,6 +2682,16 @@
skip_transform : found_infinite
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)
- op : variable_length_memory_efficient_attention
args : (Tensor query, Tensor key, Tensor value, Tensor seq_lens, Tensor kv_seq_lens, Tensor mask, float scale, bool causal)
output : Tensor
infer_meta :
func : VariableLengthMemoryEfficientAttentionInferMeta
kernel :
func : variable_length_memory_efficient_attention
data_type : query
optional : mask
- op : view_dtype
args : (Tensor input, DataType dtype)
output : Tensor(out)
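For reference, the op registered above computes scaled-dot-product attention where each batch entry only attends over its valid prefix: seq_lens[b] query rows and kv_seq_lens[b] key/value rows. A minimal NumPy sketch of the intended semantics (illustrative only; the helper name, the mask layout, and the causal-masking convention are assumptions, and the CUTLASS kernels below are the actual implementation):

import numpy as np

def variable_length_attention_ref(q, k, v, seq_lens, kv_seq_lens, scale,
                                  mask=None, causal=False):
    # q, k: [batch, num_head, seq_len, head_size]; v: [batch, num_head, kv_seq_len, value_head_size]
    # mask (if given) is assumed broadcastable to [batch, num_head, seq_len, kv_seq_len].
    batch, num_head, q_len, _ = q.shape
    out = np.zeros((batch, num_head, q_len, v.shape[-1]), dtype=np.float32)
    for b in range(batch):
        sq, sk = int(seq_lens[b]), int(kv_seq_lens[b])
        for h in range(num_head):
            s = scale * q[b, h, :sq].astype(np.float32) @ k[b, h, :sk].astype(np.float32).T
            if mask is not None:
                full_mask = np.broadcast_to(mask, (batch, num_head, q_len, k.shape[2]))
                s = s + full_mask[b, h, :sq, :sk]
            if causal:
                s = s + np.triu(np.full((sq, sk), -np.inf), k=1)  # mask out future keys
            p = np.exp(s - s.max(axis=-1, keepdims=True))
            p = p / p.sum(axis=-1, keepdims=True)
            out[b, h, :sq] = p @ v[b, h, :sk].astype(np.float32)
    return out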
@@ -2570,6 +2570,81 @@ void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
seed_and_offset->set_dtype(phi::DataType::INT64);
}
void VariableLengthMemoryEfficientAttentionInferMeta(
const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& seq_lens,
const MetaTensor& kv_seq_lens,
const MetaTensor& mask,
float scale,
bool causal,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(
query.dims().size(),
4,
phi::errors::InvalidArgument("Query should be a 4-D tensor"
"But received Query dimension(%s)",
query.dims().size()));
PADDLE_ENFORCE_EQ(
key.dims().size(),
4,
phi::errors::InvalidArgument("Key should be a 4-D tensor"
"But received Key dimension(%s)",
key.dims().size()));
PADDLE_ENFORCE_EQ(
value.dims().size(),
4,
phi::errors::InvalidArgument("Value should be a 4-D tensor"
"But received Value dimension(%s)",
value.dims().size()));
const int64_t query_batch_size = query.dims()[0];
const int64_t query_num_head = query.dims()[1];
const int64_t query_seq_length = query.dims()[2];
const int64_t query_head_size = query.dims()[3];
const int64_t key_batch_size = key.dims()[0];
const int64_t key_num_head = key.dims()[1];
const int64_t key_seq_length = key.dims()[2];
const int64_t key_head_size = key.dims()[3];
const int64_t value_batch_size = value.dims()[0];
const int64_t value_num_head = value.dims()[1];
const int64_t value_seq_length = value.dims()[2];
const int64_t value_head_size = value.dims()[3];
PADDLE_ENFORCE_EQ(
((query_batch_size == key_batch_size) &&
(key_batch_size == value_batch_size)),
true,
phi::errors::InvalidArgument(
"The batch size of Query, Key, Value should be equal."));
PADDLE_ENFORCE_EQ(
((query_num_head == key_num_head) && (key_num_head == value_num_head)),
true,
phi::errors::InvalidArgument(
"The head number of Query, Key, Value should be equal."));
PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
true,
phi::errors::InvalidArgument(
"The head size of Query, Key should be equal."));
PADDLE_ENFORCE_EQ(key_seq_length == value_seq_length,
true,
phi::errors::InvalidArgument(
"The seq length of Key, Value should be equal."));
std::vector<int64_t> out_dims(
{query_batch_size, query_num_head, query_seq_length, value_head_size});
out->set_dims(phi::make_ddim(out_dims));
out->set_dtype(query.dtype());
out->set_layout(query.layout());
}
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size();
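The InferMeta added above only checks ranks and the agreement of the batch, head, and sequence dimensions; the output borrows the query's batch size, head count, and sequence length, and takes the value's head size. A compact restatement of that contract (a sketch, not a Paddle API):

def infer_out_shape(q_shape, k_shape, v_shape):
    # Every input is rank-4: [batch_size, num_head, seq_len, head_size].
    assert len(q_shape) == len(k_shape) == len(v_shape) == 4
    assert q_shape[0] == k_shape[0] == v_shape[0]  # batch sizes equal
    assert q_shape[1] == k_shape[1] == v_shape[1]  # head numbers equal
    assert q_shape[3] == k_shape[3]                # Q/K head sizes equal
    assert k_shape[2] == v_shape[2]                # K/V sequence lengths equal
    # Output: [query_batch_size, query_num_head, query_seq_length, value_head_size]
    return [q_shape[0], q_shape[1], q_shape[2], v_shape[3]]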
@@ -471,6 +471,17 @@ void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
MetaTensor* logsumexp,
MetaTensor* seed_and_offset);
void VariableLengthMemoryEfficientAttentionInferMeta(
const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& seq_lens,
const MetaTensor& kv_seq_lens,
const MetaTensor& mask,
float scale,
bool causal,
MetaTensor* out);
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);
@@ -69,6 +69,13 @@ if(WITH_CUTLASS)
--cuda_arch "${NVCC_ARCH_BIN}"
RESULT_VARIABLE memory_efficient_attention_gen_res)
execute_process(
COMMAND
${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py
--cuda_arch "${NVCC_ARCH_BIN}"
RESULT_VARIABLE memory_efficient_attention_gen_res)
if(NOT memory_efficient_attention_gen_res EQUAL 0)
message(
FATAL_ERROR
@@ -79,9 +86,11 @@ if(WITH_CUTLASS)
file(
GLOB cutlass_cu
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"fusion/cutlass/conv2d/generated/*.cu" "fusion/cutlass/conv2d/*.cu"
"fusion/cutlass/conv2d/generated/*.cu"
"fusion/cutlass/conv2d/*.cu"
"fusion/cutlass/*.cu"
"fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu")
"fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu"
"fusion/cutlass/memory_efficient_attention/autogen_variable/impl/*.cu")
list(APPEND kernel_cu ${cutlass_cu})
endif()
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix
multiply-add with the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major
outputs are accommodated by exchanging A and B operands and assuming
transposed layouts. Partial specializations here choose
'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "./gemm_kernel_utils.h"
#include "gemm/attention_scaling_coefs_updater.h"
#include "gemm/find_default_mma.h"
#include "gemm/fmha_grouped.h"
#include "gemm/mma_from_smem.h"
#include "transform/tile_smem_loader.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag_,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
bool maskIsAligned_,
int QueriesPerBlock_,
int KeysPerBlock_,
bool SingleValueIteration_,
GroupScheduleMode GroupScheduleMode_,
bool AddMask,
bool MaskBroadcastRow>
struct DefaultFMHAGrouped {
using scalar_t = scalar_t_;
using accum_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
using ArchTag = ArchTag_;
static bool const kIsAligned = isAligned_;
static bool const kAddMask = AddMask;
static bool const kMaskBroadcastRow = MaskBroadcastRow;
static bool const kSingleValueIteration = SingleValueIteration_;
static int const kKeysPerBlock = KeysPerBlock_;
static bool const kMaskIsAligned = maskIsAligned_;
static int const kWarpSize = 32;
static int const kNumWarpsPerBlock =
QueriesPerBlock_ * KeysPerBlock_ / (kWarpSize * kWarpSize);
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
into a shared-memory ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1)
*/
using GemmType = gemm_kernel_utils::DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using ElementA = scalar_t;
using ElementB = scalar_t;
using ElementC = scalar_t;
using ElementAccumulator = accum_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
ElementA,
ElementB,
ElementC,
ElementAccumulator>;
static int const kAlignmentA =
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
static int const kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<QueriesPerBlock_, KeysPerBlock_, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
static int const kStages = DefaultConfig::kStages;
using Operator = typename GemmType::Operator;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
typename Mma::Operator::IteratorC,
ElementAccumulator,
kWarpSize>::Updater;
static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, "");
// used for efficient load of mask_ tile Bij from global to shared memory
using MaskLoader = TileSmemLoader<
scalar_t,
cutlass::MatrixShape<QueriesPerBlock_, KeysPerBlock_>,
MmaCore::kThreads,
// input restriction: kv_len has to be a multiple of this value
kMaskIsAligned ? 128 / cutlass::sizeof_bits<scalar_t>::value : 1>;
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/*
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
using GemmType = typename MM0::GemmType;
using OpClass = typename GemmType::OpClass;
using ElementA = scalar_t;
using ElementB = scalar_t;
using ElementC = output_accum_t;
using ElementAccumulator = accum_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
ElementA,
ElementB,
ElementC,
ElementAccumulator>;
static int const kAlignmentA = DefaultConfig::kAlignmentA;
static int const kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = typename MM0::ThreadblockShape;
using WarpShape = typename MM0::WarpShape;
using InstructionShape = typename MM0::InstructionShape;
using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp;
static int const kStages = DefaultConfig::kStages;
using Operator = typename GemmType::Operator;
using ThreadblockSwizzle = void; // Swizzling is unused
static bool const kSplitKSerial = false;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
false>;
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(WarpCount::kCount == kNumWarpsPerBlock, "");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
/// Define the kernel in terms of the default kernel
using FMHAKernel = kernel::FMHAGrouped<MM0,
MM1,
scalar_t,
accum_t,
output_t,
output_accum_t,
SingleValueIteration_,
GroupScheduleMode_,
AddMask,
MaskBroadcastRow,
maskIsAligned_>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
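Taken together, MM0 and MM1 implement a streamed attention tile: for each query block the kernel walks over key/value blocks, computes scale * Q @ K.T (MM0), folds the block into the running softmax statistics (mi, m_prime, s_prime), and accumulates the partial attn @ V product (MM1), normalizing at the end. A NumPy sketch of that flow for a single (batch, head) problem, ignoring masking and alignment concerns (the block size and function name are illustrative assumptions):

import numpy as np

def fmha_streamed_ref(q, k, v, scale, keys_per_block=64):
    # q: [Mq, D], k: [Mk, D], v: [Mk, Dv] for a single (batch, head) problem.
    Mq, Dv = q.shape[0], v.shape[1]
    mi = np.full(Mq, -np.inf)      # running row max of the scaled scores
    s_prime = np.zeros(Mq)         # running softmax denominator
    out = np.zeros((Mq, Dv))       # unnormalized output accumulator
    for start in range(0, k.shape[0], keys_per_block):
        kb = k[start:start + keys_per_block]
        vb = v[start:start + keys_per_block]
        s = scale * q @ kb.T                      # MM0: Q @ K.T for this key block
        mi_new = np.maximum(mi, s.max(axis=1))    # update the per-row max
        rescale = np.exp(mi - mi_new)             # factor for previously accumulated values
        p = np.exp(s - mi_new[:, None])           # unnormalized attention weights
        s_prime = s_prime * rescale + p.sum(axis=1)
        out = out * rescale[:, None] + p @ vb     # MM1: attn @ V for this key block
        mi = mi_new
    return out / s_prime[:, None]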
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
 * 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "../gemm_kernel_utils.h"
#include "../kernel_forward.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
We have multiple implementations because each configuration iterates over
its accumulators in a different way.
*/
template <typename BASE, typename T, typename accum_t, int kWarpSize>
struct RegisterOps {
template <int kQueriesPerBlock,
int kNumWarpsPerBlock,
int kWarpN,
bool kFullColumns,
bool kIsFirst,
bool kKeepOutputInRF>
CUTLASS_DEVICE static void update(
typename T::Fragment& frag_o, // NOLINT
typename T::Fragment& frag, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock>& mi, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock * kWarpN>&
addition_storage, // NOLINT
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename T::TensorCoord const& tile_offset,
float scaling) {
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
BASE::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<typename T::Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
BASE::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
BASE::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (BASE::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
addition_storage[accum_m + kQueriesPerBlock *
tile_offset.column()] = total_row;
// atomicAdd(&s_prime[accum_m], total_row);
}
});
__syncthreads();
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
total_row = s_prime[id];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kWarpN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
s_prime[id] = total_row;
}
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSm80
: RegisterOps<AttentionScalingCoefsUpdaterSm80<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
static int const kElementsPerAccess = InstructionShape::kN / 4;
static int const kRowsPerTile = 8;
static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow,
lane_in_quad * kElementsPerAccess +
tile_offset.column() * Shape::kColumn);
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
// See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < kAccumulatorRows; ++row) {
int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
row * kRowsPerTile + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
(mma_n * Policy::MmaIterations::kRow + mma_m);
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < kElementsPerAccess; ++col) {
int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
col + lane_offset.column();
int idx = mma_accum_start + row * kElementsPerAccess + col;
op(accum_m, accum_n, idx);
}
}
endRow(accum_m);
}
}
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) {
// In each warp, 4 threads will work on the same row
// - the ones with the same `quad`
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
myValue = fn(myValue, otherV);
otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
myValue = fn(myValue, otherV);
int lane_in_quad = (lane_id & 3);
return lane_in_quad == 0;
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterVolta
: RegisterOps<AttentionScalingCoefsUpdaterVolta<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
using Element = accum_t;
static int const kElementsPerPartial = 4;
using EleShapePerPatial = typename cutlass::platform::conditional<
cutlass::platform::is_same<Element, float>::value,
cutlass::MatrixShape<2, 2>,
cutlass::MatrixShape<1, 4>>::type;
static int const kElementsPerMma = 8;
static int const kAccumulatorPatials = 2;
using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
int accum_m, accum_n;
if (cutlass::platform::is_same<Element, float>::value) {
// (quad[2],quad[0])+lane_in_quad[0]
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
// (quad[1])+lane_in_quad[1]
accum_n =
((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
(lane_in_quad & 2);
} else {
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
lane_in_quad; // (quad[2],quad[0])
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
}
return cutlass::MatrixCoord(
accum_m + tile_offset.row() * Shape::kRow,
accum_n + tile_offset.column() * Shape::kColumn);
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) {
static_assert(cutlass::platform::is_same<Element, float>::value,
"update to support non-float accum");
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
// T0 & T2 share same line within a quad
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
myValue = fn(myValue, otherV);
// quad 0 and quad 2 are on the same lines
otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
myValue = fn(myValue, otherV);
return (lane_id & ((1 << 1) | (1 << 3))) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
int accum_m = tile_m * Policy::InterleavedTile::kRow +
mma_m * QuadShapePerPatialMma::kRow + m * 2 +
lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn;
++tile_n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kAccumulatorPatials; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
int mma_accum_start =
(((tile_n * Policy::TileIterations::kRow + tile_m) *
Policy::MmaIterations::kColumn +
mma_n) *
Policy::MmaIterations::kRow +
mma_m) *
kElementsPerMma;
int accum_n = tile_n * Policy::InterleavedTile::kColumn +
mma_n * QuadShapePerPatialMma::kColumn +
p * Policy::InterleavedTile::kColumn / 2 + n +
lane_offset.column();
int idx = mma_accum_start + p * kElementsPerPartial +
m * EleShapePerPatial::kColumn + n;
op(accum_m, accum_n, idx);
}
}
}
}
endRow(accum_m);
}
}
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterSimt
: RegisterOps<AttentionScalingCoefsUpdaterSimt<T, accum_t, kWarpSize>,
T,
accum_t,
kWarpSize> {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
using Delta = typename T::Delta;
using Shape = typename T::Shape;
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) {
CUTLASS_PRAGMA_UNROLL
for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
myValue = fn(myValue, otherV);
}
return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
int accum_n =
mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
lane_offset.column();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
int idx =
n + Policy::LaneMmaShape::kN *
(mma_n + Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
op(accum_m, accum_n + n, idx);
}
}
endRow(accum_m);
}
}
}
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
static_assert(cutlass::platform::is_same<
typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value,
"");
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset =
lane_layout.inverse(lane_id) *
cutlass::MatrixCoord(Policy::LaneMmaShape::kM,
Policy::LaneMmaShape::kN);
return lane_offset +
tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
}
};
template <typename T, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater;
// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
cutlass::gemm::warp::MmaSimtTileIterator<S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>,
accum_t,
kWarpSize> {
using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>;
using Updater =
AttentionScalingCoefsUpdaterSimt<Iterator, accum_t, kWarpSize>;
};
// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
using Iterator =
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
using Updater =
AttentionScalingCoefsUpdaterVolta<Iterator, accum_t, kWarpSize>;
};
// TensorOp - Sm75+
template <typename S1,
typename S2,
typename S3,
typename accum_t,
int kWarpSize>
struct DefaultAttentionScalingCoefsUpdater<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>,
accum_t,
kWarpSize> {
using Iterator =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>;
using Updater =
AttentionScalingCoefsUpdaterSm80<Iterator, accum_t, kWarpSize>;
};
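All three updaters above implement the recurrence listed in steps (1)-(2) of the file comment, differing only in how they walk the accumulator fragment for the SIMT, Volta, or Sm75+ tensor-op layouts. On a whole tile the update reads roughly as follows (a simplified sketch; the real code works per register fragment and uses the exp2f/log2(e) trick instead of exp):

import numpy as np

def update_tile(frag, mi, m_prime, s_prime, scaling):
    # frag: [rows, cols] partial Q @ K.T accumulator; mi, m_prime, s_prime: [rows].
    mi_new = np.maximum(mi, scaling * frag.max(axis=1))      # (1) new per-row max
    m_prime_new = np.exp(mi - mi_new)                        # (2b) rescale factor exp(old_max - new_max)
    frag = np.exp(scaling * frag - mi_new[:, None])          # (2a) exp(accum - mi)
    s_prime_new = s_prime * m_prime_new + frag.sum(axis=1)   # (2c) rescale and add row sums
    return frag, mi_new, m_prime_new, s_prime_new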
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Base device-level grouped kernel.
*/
#pragma once
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename BaseKernel_>
class BaseGrouped {
public:
using BaseKernel = BaseKernel_;
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator =
typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
/// Argument structure
using Arguments = typename BaseKernel::Arguments;
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
/// Kernel parameters object
typename BaseKernel::Params params_;
private:
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(
const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) {
int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i) {
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
tiles += problem_tile_count(problem);
}
return tiles;
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes) {
cudaError_t cuda_error =
cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
if (cuda_error != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error "
<< cudaGetErrorString(cuda_error));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const& args,
int32_t tile_count,
void* workspace) {
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(
args.host_problem_sizes,
args.problem_count,
args.threadblock_count,
reinterpret_cast<void*>(host_workspace.data()));
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
}
/// Reorder `data` according to `indices`
template <typename T>
static void reorder_array(T* data, const std::vector<size_t>& indices) {
// For now, simply create a copy of the data and then copy over to the
// original.
std::vector<T> copy(indices.size());
for (int i = 0; i < indices.size(); ++i) {
copy.at(i) = data[indices[i]];
}
memcpy(data, copy.data(), indices.size() * sizeof(T));
}
public:
/// Constructs the GEMM.
BaseGrouped() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args) {
return BaseKernel::can_implement(args);
}
/// Get the number of tiles in a problem
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) {
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
return BaseKernel::ProblemVisitor::tile_count(grid);
}
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(Arguments const& args) {
if (args.host_problem_sizes == nullptr) {
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
return -1;
}
return group_tile_count(args.host_problem_sizes, args.problem_count);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args) {
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
return BaseKernel::ProblemVisitor::get_workspace_size(
args.host_problem_sizes, args.problem_count, args.threadblock_count);
} else {
return 0;
}
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args) {
return dim3(args.threadblock_count, 1, 1);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()");
int smem_size =
static_cast<int>(sizeof(typename BaseKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<BaseKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error "
<< cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result =
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
Kernel<BaseKernel>,
BaseKernel::kThreadCount,
smem_size);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Sorts each pointer passed in according to the indices that sort
/// `problem_sizes_ptr` in descending order of problem-K dimension.
static void sort_problems(int problem_count,
cutlass::gemm::GemmCoord* problem_sizes_ptr,
int64_t* lda_host_ptr,
int64_t* ldb_host_ptr,
int64_t* ldc_host_ptr,
int64_t* ldd_host_ptr,
int64_t* offset_A_ptr,
int64_t* offset_B_ptr,
int64_t* offset_C_ptr,
int64_t* offset_D_ptr) {
std::vector<size_t> indices(problem_count);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(),
indices.end(),
[&problem_sizes_ptr](size_t i, size_t j) {
return problem_sizes_ptr[i].k() >
problem_sizes_ptr[j].k();
});
reorder_array(problem_sizes_ptr, indices);
reorder_array(lda_host_ptr, indices);
reorder_array(ldb_host_ptr, indices);
reorder_array(ldc_host_ptr, indices);
reorder_array(ldd_host_ptr, indices);
reorder_array(offset_A_ptr, indices);
reorder_array(offset_B_ptr, indices);
reorder_array(offset_C_ptr, indices);
reorder_array(offset_D_ptr, indices);
}
/// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient(
const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr,
int problem_count = 0,
int available_sm_count = -1) {
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
// printf("custom base\n");
static cudaDeviceProp properties;
static bool count = true;
if (count) {
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error "
<< cudaGetErrorString(result));
return 0;
}
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error "
<< cudaGetErrorString(result));
return 0;
}
}
count = false;
bool override_sm_count =
(available_sm_count < 0 ||
available_sm_count > properties.multiProcessorCount);
if (override_sm_count) {
available_sm_count = properties.multiProcessorCount;
}
int max_active_blocks = maximum_active_blocks();
if (max_active_blocks <= 0) {
return 0;
}
int occupancy_based_block_count = available_sm_count * max_active_blocks;
if (problem_sizes_ptr == nullptr || problem_count == 0) {
return occupancy_based_block_count;
}
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
// If the group contains a single problem, launching the exact number of
// threadblocks needed to cover the problem minimizes the work performed
// per threadblock in finding the next tile to compute. We return
// total_tiles unless the user has provided the SM count.
if (problem_count == 1 && override_sm_count) {
return total_tiles;
}
// Choose between the full wave of threadblocks and the tile count. If there
// are fewer tiles in the group than threadblocks in the full wave, only
// some threadblocks will be assigned tiles. Those threadblocks
// which are not assigned tiles still need to perform the work of iterating
// through problem sizes to determine that they have no work to do. This
// competes for cycles with those threadblocks that are assigned tiles to
// compute.
return min(total_tiles, occupancy_based_block_count);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args,
void* workspace = nullptr,
cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace "
<< workspace
<< ", stream: " << (stream ? "non-null" : "null"));
// Workspace
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess) {
return status;
}
params_ = typename BaseKernel::Params(args, workspace, tile_count);
} else {
params_ = typename BaseKernel::Params(args, workspace);
}
// Specify shared memory capacity for kernel.
int smem_size =
static_cast<int>(sizeof(typename BaseKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result =
cudaFuncSetAttribute(Kernel<BaseKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr) {
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess) {
return status;
}
params_.update(args, workspace, tile_count);
} else {
params_.update(args, workspace);
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
//
// Configure grid and block dimensions
//
if (!params_.problem_visitor.problem_count) {
return Status::kSuccess;
}
dim3 grid(params_.threadblock_count, 1, 1);
dim3 block(BaseKernel::kThreadCount, 1, 1);
int smem_size =
static_cast<int>(sizeof(typename BaseKernel::SharedStorage));
//
// Launch kernel
//
// Launch
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" grid launch failed with error "
<< cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) { return run(stream); }
/// Initializes and runs the kernel.
Status operator()(Arguments const& args,
void* workspace,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
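The launch sizing in BaseGrouped::sufficient() reduces to a small decision: launch at most one full wave of maximally occupied SMs, capped by the number of tiles in the group, except that a single-problem group with no user-pinned SM count gets exactly one threadblock per tile. A sketch of that arithmetic (illustrative only):

def grouped_threadblock_count(total_tiles, sm_count, max_active_blocks_per_sm,
                              single_problem=False, user_pinned_sm_count=False):
    if max_active_blocks_per_sm <= 0:
        return 0
    full_wave = sm_count * max_active_blocks_per_sm  # one wave at maximum occupancy
    if single_problem and not user_pinned_sm_count:
        return total_tiles  # exact cover: each block finds its tile immediately
    return min(total_tiles, full_wave)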
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Scheduler for grouped FMHA
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Helper for correctly representing problem sizes in grouped kernels
template <typename ThreadblockShape>
struct FMHAGroupedProblemSizeHelper {
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(
const cutlass::gemm::GemmCoord &problem) {
// FMHA only partitions tiles across the M dimension.
return cutlass::gemm::GemmCoord(
((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
1,
1);
}
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(
cutlass::gemm::GemmCoord &problem) { // NOLINT
} // NOLINT
CUTLASS_HOST_DEVICE
static int32_t tile_count(const cutlass::gemm::GemmCoord &grid) {
return grid.m() * grid.n();
}
};
} // namespace detail
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount,
bool Transposed = false>
struct FMHAGroupedProblemVisitor
: public GroupedProblemVisitor<
detail::FMHAGroupedProblemSizeHelper<ThreadblockShape>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {
using ProblemSizeHelper =
detail::FMHAGroupedProblemSizeHelper<ThreadblockShape>;
using Base = GroupedProblemVisitor<ProblemSizeHelper,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount>;
using BaseParams = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
struct Params {
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
int32_t problem_count;
void const *workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params()
: problem_sizes0(nullptr),
problem_sizes1(nullptr),
problem_count(0),
workspace(nullptr),
tile_count(0) {}
/// Ctor
CUTLASS_HOST_DEVICE
Params(cutlass::gemm::GemmCoord const *problem_sizes0,
cutlass::gemm::GemmCoord const *problem_sizes1,
int32_t problem_count,
void const *workspace = nullptr,
int32_t tile_count = 0)
: problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
workspace(workspace),
tile_count(tile_count) {}
/// Convert the FMHA-specific parameters to those used by the base class
CUTLASS_HOST_DEVICE
BaseParams to_base() const {
return BaseParams( // Set problem_sizes as problem_sizes1 because these
// determine shape of the final output of FMHA
problem_sizes1,
problem_count,
workspace,
tile_count);
}
};
//
// Methods
//
CUTLASS_DEVICE
FMHAGroupedProblemVisitor(Params const &params_,
SharedStorage &shared_storage_, // NOLINT
int32_t block_idx)
: Base(params_.to_base(), shared_storage_, block_idx),
problem_sizes0(params_.problem_sizes0),
problem_sizes1(params_.problem_sizes1) {}
/// Returns the problem size 0 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size0() const {
GemmCoord problem = problem_sizes0[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
/// Returns the problem size 1 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size1() const {
GemmCoord problem = problem_sizes1[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
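The grid/tile arithmetic above is intentionally simple: FMHA partitions each problem only along M, so the grid is ceil(M / ThreadblockShape::kM) x 1 x 1 and the tile count is grid.m() * grid.n(). A one-line restatement (sketch):

def fmha_tile_count(problem_m, threadblock_m):
    grid_m = (problem_m + threadblock_m - 1) // threadblock_m  # ceil division along M
    return grid_m * 1  # grid.n() == 1, so the tile count equals grid_m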
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Device-level grouped GEMM.
*/
#pragma once
#include "./base_grouped.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename GemmKernel_>
class GemmGrouped : public BaseGrouped<GemmKernel_> {
public:
using GemmKernel = GemmKernel_;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Generates combinations of kernels - implementations and registry
# Kernels are ordered (see `sort_index`), and when dispatching,
# we select the first kernel in the list that supports the inputs
import argparse
import collections
import itertools
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypeVar
DEFAULT_ARCH = [50, 70, 75, 80]
MAX_ARCH = 90
ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION"
assert sorted(DEFAULT_ARCH) == DEFAULT_ARCH
def find_arch_range(min_arch, max_arch):
assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH
assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH
assert min_arch <= max_arch
n = len(DEFAULT_ARCH)
start_idx = n - 1
for i in range(n - 1):
if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]:
start_idx = i
break
end_idx = n
for i in range(n - 1):
if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]:
end_idx = i + 1
return DEFAULT_ARCH[start_idx:end_idx]
def find_max_arch(arch):
arch = sorted(arch)
idx = DEFAULT_ARCH.index(arch[-1])
if idx == len(DEFAULT_ARCH) - 1:
return MAX_ARCH
else:
return DEFAULT_ARCH[idx + 1]
def convert_to_arch_list(arch):
arch = arch.lower().strip()
if arch == "all":
return DEFAULT_ARCH
arch = [int(s.strip()) for s in arch.split(';') if s.strip()]
arch = list(set(arch))
arch.sort()
return find_arch_range(arch[0], arch[-1])
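# Illustrative examples of the arch handling above, assuming the DEFAULT_ARCH
# list [50, 70, 75, 80] defined earlier (values worked out by hand):
#   convert_to_arch_list("All")   -> [50, 70, 75, 80]
#   convert_to_arch_list("75;80") -> [75, 80]
#   find_max_arch([75, 80])       -> 90  (MAX_ARCH, because 80 is the last entry)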
def parse_args():
parser = argparse.ArgumentParser(
description="The arguments for generating the memory efficient kernels."
)
parser.add_argument(
"--dst_path",
type=str,
default=str(Path(__file__).parent),
help="The destination path to save the generated files.",
)
parser.add_argument(
"--cuda_arch",
type=convert_to_arch_list,
default=convert_to_arch_list("All"),
help="The CUDA architecture to be generated.",
)
args = parser.parse_args()
args.max_arch = find_max_arch(args.cuda_arch)
return args
args = parse_args()
DTYPES = {
"f32": "float",
"f16": "cutlass::half_t",
"bf16": "cutlass::bfloat16_t",
}
SM = args.cuda_arch
KERNEL_IMPL_TEMPLATE = """
void {NAME}({CPP_CLASS} default_fmha, Params &params, const phi::GPUContext& ctx) {{
using AttentionKernel = typename decltype(default_fmha)::FMHAKernel;
using FMHA = cutlass::gemm::device::GemmGrouped<AttentionKernel>;
using scalar_t = typename FMHA::GemmKernel::scalar_t;
using accum_t = typename FMHA::GemmKernel::accum_t;
using output_t = typename FMHA::GemmKernel::output_t;
using output_accum_t = typename FMHA::GemmKernel::output_accum_t;
using ElementQ = scalar_t;
using ElementK = scalar_t;
using ElementP = accum_t;
using ElementM = scalar_t;
using ElementAccumulator = accum_t;
using ElementV = scalar_t;
using ElementO = output_t;
using ElementOAccum = output_accum_t;
int problem_count = params.num_batches * params.num_heads;
std::vector<GemmCoord> problem_sizes1;
problem_sizes1.resize(problem_count);
phi::Allocator::AllocationPtr problem_sizes_device0{{nullptr}};
phi::Allocator::AllocationPtr problem_sizes_device1{{nullptr}};
problem_sizes_device0 = phi::memory_utils::Alloc(
ctx.GetPlace(),
problem_count * sizeof(GemmCoord),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
problem_sizes_device1 = phi::memory_utils::Alloc(
ctx.GetPlace(),
problem_count * sizeof(GemmCoord),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
GemmCoord* problem0_device =
reinterpret_cast<GemmCoord*>(problem_sizes_device0->ptr());
GemmCoord* problem1_device =
reinterpret_cast<GemmCoord*>(problem_sizes_device1->ptr());
get_problem_sizes<<<params.num_batches, params.num_heads, 0, ctx.stream()>>>(
params.seq_lens,
params.kv_seq_lens,
problem0_device,
problem1_device,
params.num_batches,
params.num_heads,
params.head_size,
params.value_head_size);
phi::memory_utils::Copy(phi::CPUPlace(),
problem_sizes1.data(),
ctx.GetPlace(),
problem1_device,
sizeof(GemmCoord) * problem_count,
ctx.stream());
if (AttentionKernel::kNeedsOutputAccumulatorBuffer) {{
const int64_t output_size = params.num_batches * params.num_heads *
params.query_seq_len * params.value_head_size;
phi::Allocator::AllocationPtr tmp_output_accum_buffer_ptr{{nullptr}};
tmp_output_accum_buffer_ptr = phi::memory_utils::Alloc(
ctx.GetPlace(),
output_size * sizeof(ElementOAccum),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
params.output_accum_ptr = tmp_output_accum_buffer_ptr->ptr();
}}
int threadblock_count =
FMHA::sufficient(problem_sizes1.data(), problem_count);
typename FMHA::Arguments args(
problem0_device,
problem1_device,
problem_count,
threadblock_count,
params.num_heads,
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.query_ptr)),
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.key_ptr)),
params.mask_ptr
? const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.mask_ptr))
: nullptr,
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.value_ptr)),
reinterpret_cast<scalar_t*>(params.output_ptr),
AttentionKernel::kNeedsOutputAccumulatorBuffer
? reinterpret_cast<output_accum_t*>(params.output_accum_ptr)
: nullptr,
params.ldq,
params.ldk,
params.ldm,
params.ldv,
params.ldo,
params.ElementQ,
params.ElementK,
params.ElementM,
params.ElementV,
params.ElementO,
params.causal,
params.scale,
problem_sizes1.data());
FMHA fmha;
cutlass::Status status;
size_t workspace_size = fmha.get_workspace_size(args);
phi::DenseTensor workspace;
workspace.Resize(phi::make_ddim({{static_cast<int64_t>(workspace_size)}}));
ctx.template Alloc<uint8_t>(&workspace);
status = fmha.initialize(args, workspace.data<uint8_t>());
if (status != cutlass::Status::kSuccess) {{
PADDLE_THROW(phi::errors::Unimplemented(
"Failed to initialize CUTLASS Grouped FMHA kernel."));
}}
status = fmha.run(ctx.stream());
if (status != cutlass::Status::kSuccess) {{
PADDLE_THROW(phi::errors::Unimplemented(
"Failed to run CUTLASS Grouped FMHA kernel."));
}}
}}
"""
@dataclass(order=True)
class FwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
aligned: bool
mask_aligned: bool
dtype: str
sm_range: Tuple[int, int]
q: int
k: int
single_value_iter: bool
support_mask: bool = True
mask_broadcast: bool = False
dispatch_cond: Optional[str] = None
def __post_init__(self) -> None:
# Set kernel selection priority
# The lowest value that matches inputs
# will be selected
self.sort_index = (
# First select aligned kernel
0 if self.aligned else 1,
0 if self.support_mask else 1,
# Then keep output in RF
0 if self.single_value_iter else 1,
self.q,
0 if self.mask_aligned else 1,
0 if self.mask_broadcast else 1,
)
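# Illustrative comparison: for two aligned, mask-supporting, mask-aligned kernels
# with tiles (64, 64, single_value_iter=True) and (32, 128, single_value_iter=False),
# the sort keys are (0, 0, 0, 64, 0, 1) and (0, 0, 1, 32, 0, 1), so the
# register-file (single_value_iter) kernel is preferred despite its larger query tile.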
@property
def _aligned_suffix(self) -> str:
return "aligned" if self.aligned else "notaligned"
@property
def _mask_aligned_suffix(self) -> str:
return "ma" if self.mask_aligned else "mua"
@property
def _mask_support_suffix(self) -> str:
return "sm" if self.support_mask else "usm"
@property
def _mask_broadcast_suffix(self) -> str:
return "mb" if self.mask_broadcast else "mnb"
@property
def _single_value_suffix(self) -> str:
return "rf" if self.single_value_iter else "urf"
@property
def name(self) -> str:
return f"fmha_cutlassF_variable_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{self._single_value_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_sm{self.sm_range[0]}"
@property
def cpp_class(self) -> str:
template_args = ", ".join(
[
DTYPES[self.dtype],
f"cutlass::arch::Sm{self.sm_range[0]}",
"true" if self.aligned else "false",
"true" if self.mask_aligned else "false",
str(self.q),
str(self.k),
"true" if self.single_value_iter else "false",
"cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
"true" if self.support_mask else "false",
"false",
]
)
return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>"
@property
def impl_group(self) -> str:
# Maps to file which will contain the implementation
return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}"
@property
def cpp_impl(self) -> str:
return KERNEL_IMPL_TEMPLATE.format(
CPP_CLASS=self.cpp_class, NAME=self.name
)
@classmethod
def get_all(cls) -> List["FwdKernel"]:
kernels: List[FwdKernel] = []
for aligned, dtype, (sm, sm_max) in itertools.product(
[True, False], DTYPES.keys(), zip(SM, SM[1:] + [args.max_arch])
):
# Remove some kernels we don't use
if dtype == "bf16" and sm < 80:
continue
if not aligned and sm >= 80:
continue
for q, k, single_value_iter in [
(32, 128, True),
(32, 128, False),
(64, 64, True),
]:
for support_mask, mask_aligned in [
(False, False),
(True, False),
(True, True),
]:
kernels.append(
cls(
aligned=aligned,
dtype=dtype,
sm_range=(sm, sm_max),
q=q,
k=k,
single_value_iter=single_value_iter,
support_mask=support_mask,
mask_aligned=mask_aligned,
mask_broadcast=False,
)
)
return kernels
T = TypeVar("T", bound=FwdKernel)
def write_decl_impl(
kernels: List[T], family_name: str, impl_file: str, enable_def: str
) -> None:
cpp_file_header = """// This file is auto-generated. See "generate_variable_forward_kernels.py"
"""
kernels.sort()
implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list)
cat_to_kernels: Dict[
Tuple[str, int, int], List[T]
] = collections.defaultdict(list)
dispatch_all = ""
declarations = cpp_file_header + "#pragma once\n"
declarations += f"#ifdef {enable_def}\n"
declarations += f"""#include "{impl_file}"\n"""
declarations += "namespace phi {\n"
# Declaration of kernel functions
for k in kernels:
implfile_to_kernels[k.impl_group].append(k)
cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k)
for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items():
declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n"
declarations += "\n".join(
k.cpp_impl.split("{")[0].rstrip() + ";" for k in kernels
)
dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}"
declarations += (
f"\n\ntemplate <typename T> void {dispatch_category_fn}(T cb) {{\n"
)
for k in kernels:
_call = f"cb({k.cpp_class}(), {k.name});\n"
if k.dispatch_cond is not None:
_call = f"if ({k.dispatch_cond}) {_call}"
declarations += f" {_call}"
declarations += "}\n\n"
dispatch_all += f"""
if (std::is_same<DT, {DTYPES[cat_dt]}>::value && {cat_sm} <= cc && cc < {cat_sm_max}) {{
{dispatch_category_fn}(cb);
}}"""
declarations += f"""
template <typename PaddleT, typename T>
void dispatch_{family_name}(const ::phi::GPUContext &ctx, T cb) {{
auto cc = ctx.GetComputeCapability();
PADDLE_ENFORCE_GE(
cc,
70,
phi::errors::InvalidArgument("the Nvidia GPU's Compute Capability must be greater or equal than 70"));
using DT = typename ::phi::CutlassTrait<PaddleT>::Type;
{dispatch_all}
}}
"""
declarations += "} // namespace phi\n"
declarations += f"#endif // {enable_def}\n"
autogen_dir = Path(args.dst_path) / "autogen_variable"
os.makedirs(autogen_dir, exist_ok=True)
declaration_path = autogen_dir / f"{family_name}.h"
declaration_path.write_text(declarations)
for f, f_kernels in implfile_to_kernels.items():
impl_cu = cpp_file_header
impl_cu += f"#ifdef {enable_def}\n"
impl_cu += f"""#include "{impl_file}"\n"""
impl_cu += "namespace phi {\n"
for k in f_kernels:
impl_cu += k.cpp_impl
impl_cu += "} // namespace phi\n"
impl_cu += f"#endif // {enable_def}\n"
impl_path = autogen_dir / "impl"
os.makedirs(impl_path, exist_ok=True)
(impl_path / f"{family_name}_{f}.cu").write_text(impl_cu)
def write_main_header():
main_header_content = '''
#pragma once
#ifdef {}
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "cutlass/util/device_memory.h"
#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h"
#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/gemm_grouped.h"
namespace phi {{
using GemmCoord = cutlass::gemm::GemmCoord;
struct Params {{
// meta params
phi::DataType datatype;
// [bs, nh, seq_len, dh]
const void* query_ptr;
const void* key_ptr;
const void* value_ptr;
// Mask tensor; it can be broadcast along axes 0, 1 and 2.
const void* mask_ptr = nullptr;
const int* seq_lens = nullptr;
const int* kv_seq_lens = nullptr;
// Output tensors
void* output_ptr; // [num_batches, num_heads, query_seq_len, head_size]
void* output_accum_ptr =
nullptr; // [num_batches, num_heads, query_seq_len, head_size]
// Scale
float scale;
// Dimensions/strides
int32_t num_batches;
int32_t num_heads;
int32_t query_seq_len;
int32_t key_value_seq_len;
int32_t head_size;
int32_t value_head_size;
int64_t ldq;
int64_t ldk;
int64_t ldm;
int64_t ldv;
int64_t ldo;
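// Per-(batch, head) element strides for Q, K, mask, V and O,
// e.g. ElementQ = query_seq_len * head_size.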
int64_t ElementQ;
int64_t ElementK;
int64_t ElementM;
int64_t ElementV;
int64_t ElementO;
bool causal;
bool mask_broadcast_row;
}};
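// get_problem_sizes builds, for each (batch, head) pair, the two grouped GEMM
// problem shapes: problem0 is the QK^T GEMM with M = query seq_len,
// N = key/value seq_len, K = head_size; problem1 is the attention-times-V GEMM
// with M = query seq_len, N = value_head_size, K = key/value seq_len.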
__global__ static void get_problem_sizes(const int* seq_lens,
const int* kv_seq_lens,
GemmCoord* problem_sizes0,
GemmCoord* problem_sizes1,
const int bs,
const int num_head,
const int head_size,
const int value_head_size) {{
int bi = blockIdx.x;
int hi = threadIdx.x;
if (bi < bs && hi < num_head) {{
int id = bi * num_head + hi;
int m = seq_lens[bi];
int mkv = kv_seq_lens[bi];
int k0 = head_size;
int k1 = value_head_size;
GemmCoord problem0(m, mkv, k0);
GemmCoord problem1(m, k1, mkv);
problem_sizes0[id] = problem0;
problem_sizes1[id] = problem1;
}}
}}
template <typename T>
struct CutlassTrait {{
using Type = T;
}};
template <>
struct CutlassTrait<dtype::float16> {{
using Type = cutlass::half_t;
}};
template <>
struct CutlassTrait<dtype::bfloat16> {{
using Type = cutlass::bfloat16_t;
}};
template <typename T>
struct ToPhiDTypeTrait {{
private:
using NonConstT = typename std::remove_const<T>::type;
static constexpr bool kIsFP16 = std::is_same<NonConstT, cutlass::half_t>::value;
static constexpr bool kIsBF16 = std::is_same<NonConstT, cutlass::bfloat16_t>::value;
public:
using Type = typename std::conditional<kIsFP16, dtype::float16,
typename std::conditional<kIsBF16, dtype::bfloat16, NonConstT>::type>::type;
}};
}} // namespace phi
#include "./cutlass_forward.h"
#endif
'''.format(
ENABLE_MACRO
)
path = Path(args.dst_path) / "autogen_variable"
os.makedirs(path, exist_ok=True)
path = Path(path) / "memory_efficient_variable_attention.h"
path.write_text(main_header_content)
if os.path.exists(Path(args.dst_path) / "autogen_variable"):
shutil.rmtree(Path(args.dst_path) / "autogen_variable")
forward_impl = "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen_variable/memory_efficient_variable_attention.h"
write_main_header()
write_decl_impl(
FwdKernel.get_all(),
"cutlass_forward",
impl_file=forward_impl,
enable_def=ENABLE_MACRO,
)
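# Usage note (a sketch; the registry header above refers to this script as
# "generate_variable_forward_kernels.py"): running it regenerates the headers and
# .cu files under <dst_path>/autogen_variable, e.g.
#   python generate_variable_forward_kernels.py --cuda_arch "70;75;80"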
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen_variable/cutlass_forward.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MultiHeadAttentionVariableForwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const DenseTensor& seq_lens,
const DenseTensor& kv_seq_lens,
const paddle::optional<DenseTensor>& mask,
const float scale,
const bool causal,
DenseTensor* output) {
ctx.template Alloc<T>(output);
Params params{};
// [B, N, S, H]
params.seq_lens = seq_lens.data<int>();
params.kv_seq_lens = kv_seq_lens.data<int>();
params.num_batches = query.dims()[0];
params.num_heads = query.dims()[1];
params.query_seq_len = query.dims()[2];
params.head_size = query.dims()[3];
params.key_value_seq_len = key.dims()[2];
params.value_head_size = value.dims()[3];
params.datatype = query.dtype();
params.query_ptr = query.data();
params.key_ptr = key.data();
params.value_ptr = value.data();
params.output_ptr = output->data();
params.ldq = params.head_size;
params.ldk = params.head_size;
params.ldv = params.value_head_size;
params.ldo = params.value_head_size;
params.ElementQ = params.query_seq_len * params.head_size;
params.ElementK = params.key_value_seq_len * params.head_size;
params.ElementV = params.key_value_seq_len * params.value_head_size;
params.ElementO = params.query_seq_len * params.value_head_size;
params.scale = scale;
params.causal = causal;
if (mask) {
// [B, 1, query_seq_len, key_seq_len]
auto mask_tensor = mask.get();
params.ldm = mask_tensor.dims()[3];
params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3];
params.mask_ptr = mask_tensor.data();
params.mask_broadcast_row = false;
}
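// Kernel selection: dispatch_cutlass_forward enumerates the generated kernels in
// priority order and calls this lambda for each candidate; the lambda skips
// candidates whose mask support, mask alignment, keys-per-block tile or head-size
// alignment do not match the runtime parameters, and launches only the first
// compatible kernel (guarded by kernel_launched).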
bool kernel_launched = false;
auto launchKernel = [&](auto k_, auto kernel_fn) {
using KernelType = decltype(k_);
if (kernel_launched) {
return;
}
if (mask && !KernelType::kAddMask) {
return;
}
if (!mask && KernelType::kAddMask) {
return;
}
if (KernelType::kMaskBroadcastRow) {
// mask broadcasting is not supported
return;
}
if (mask && reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 &&
params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) {
return;
}
if (mask &&
!(reinterpret_cast<uintptr_t>(params.mask_ptr) % 16 == 0 &&
params.ldm % (16 / sizeof(T)) == 0) &&
KernelType::kMaskIsAligned) {
return;
}
if (KernelType::kSingleValueIteration &&
KernelType::kKeysPerBlock < params.value_head_size) {
return;
}
if (KernelType::kKeysPerBlock == 64 && params.value_head_size > 64) {
return;
}
if (params.head_size % KernelType::MM0::kAlignmentA) {
return;
}
kernel_launched = true;
kernel_fn(k_, params, ctx);
};
dispatch_cutlass_forward<T, decltype(launchKernel)>(ctx, launchKernel);
PADDLE_ENFORCE_EQ(
kernel_launched,
true,
phi::errors::InvalidArgument("the kernel should not be launched"));
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(variable_length_memory_efficient_attention,
GPU,
ALL_LAYOUT,
phi::fusion::MultiHeadAttentionVariableForwardKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(3).SetDataType(phi::DataType::INT32);
kernel->InputAt(4).SetDataType(phi::DataType::INT32);
}
......@@ -25,6 +25,9 @@ from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
from .fused_gate_attention import fused_gate_attention
from .fused_rotary_position_embedding import fused_rotary_position_embedding
from .variable_length_memory_efficient_attention import (
variable_length_memory_efficient_attention,
)
from .rms_norm import rms_norm
__all__ = [
......@@ -38,5 +41,6 @@ __all__ = [
'fused_ec_moe',
'fused_dropout_add',
'fused_rotary_position_embedding',
'variable_length_memory_efficient_attention',
"rms_norm",
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following codes are from https://github.com/facebookresearch/xformers
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode
def variable_length_memory_efficient_attention(
query,
key,
value,
seq_lens,
kv_seq_lens,
mask=None,
scale=None,
causal=False,
):
"""
Cutlass Variable-Length Memory Efficient Attention.
This method requires the GPU's SM architecture to be one of sm70, sm75 or sm80.
Args:
query (Tensor): The Query Tensor. Its shape is [batchsize, num_head, seq_len, head_size].
key (Tensor): The Key Tensor. Its shape is [batchsize, num_head, kv_seq_len, head_size].
value (Tensor): The Value Tensor. Its shape is [batchsize, num_head, kv_seq_len, head_size].
seq_lens (Tensor): The sequence lengths of the sequences in the batch, used to index query. Its shape is [batchsize, 1].
kv_seq_lens (Tensor): The sequence lengths of the sequences in the batch, used to index key and value. Its shape is [batchsize, 1].
mask (Tensor, optional): The Mask Tensor. Its shape is [batchsize, 1, query_seq_len, key_seq_len]. Default is None.
scale (float): The scaling factor applied to the attention scores. Default is 1.0 / sqrt(head_size).
causal (bool): Whether causal masking is applied. Default is False.
Returns:
Tensor: the output Tensor.
Examples:
.. code-block:: python
# required: gpu
import math
import paddle
from paddle.incubate.nn.functional import variable_length_memory_efficient_attention
batch = 1
num_head = 8
seq_len = 256
head_size = 32
dtype = paddle.float16
query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32')
mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype)
scale = float(1.0 / math.sqrt(head_size))
def naive_attention_impl(query, key, value, mask, scale):
qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(softmax_result, value)
return result
out = naive_attention_impl(query, key, value, mask, scale)
# equals to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale)
print(out.shape) # [batch, num_head, seq_len, head_size]
"""
if scale is None:
head_size = query.shape[3]
scale = float(1.0 / math.sqrt(head_size))
if in_dynamic_mode():
return _C_ops.variable_length_memory_efficient_attention(
query, key, value, seq_lens, kv_seq_lens, mask, scale, causal
)
helper = LayerHelper(
'variable_length_memory_efficient_attention', **locals()
)
out = helper.create_variable_for_type_inference(dtype=query.dtype)
helper.append_op(
type='variable_length_memory_efficient_attention',
inputs={
'query': query,
'key': key,
'value': value,
'seq_lens': seq_lens,
'kv_seq_lens': kv_seq_lens,
"mask": mask,
},
attrs={"scale": scale, "causal": causal},
outputs={'out': out},
)
return out
......@@ -158,6 +158,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_rms_norm_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_matmul_int8_op)
list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.framework import core
from paddle.incubate.nn.functional import (
variable_length_memory_efficient_attention,
)
from paddle.static import Program, program_guard
paddle.seed(2023)
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def create_attn_mask(
mask_type,
batch_size,
seq_lens,
):
max_seq_len = max(seq_lens)
mask = paddle.zeros(
[batch_size, 1, max_seq_len, max_seq_len], dtype=mask_type
)
for i in range(batch_size):
seq_len = seq_lens[i]
mask[i, 0, :seq_len, :seq_len] = (
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type))
- 1
) * 1e4
return mask
def naive_attention_impl(query, key, value, mask, scale):
qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(softmax_result, value)
return result
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.2",
)
class TestMemEffAttentionVariableAPI(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariable_fp32"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.num_head = 8
self.seq_len = 64
self.dim_head = 16
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float32'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])
def test_all(self):
paddle.disable_static()
query = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
key = np.random.random(self.shape)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
value = np.random.random(self.shape)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
out_ = naive_attention_impl(q, k, v, self.attention_mask, self.scale)
out = variable_length_memory_efficient_attention(
q,
k,
v,
self.seq_lens,
self.seq_lens,
self.attention_mask,
self.scale,
)
for i in range(self.batch_size):
np.testing.assert_allclose(
out.numpy()[i, :, : self.seq_lens[i], :],
out_[i, :, : self.seq_lens[i], :],
rtol=5e-03,
atol=1e-03,
)
class TestMemEffAPIVariableDtypeFP16(TestMemEffAttentionVariableAPI):
def setUp(self):
self.name = "MemEffAPIVariable_fp16"
self.place = paddle.CUDAPlace(0)
self.batch_size = 3
self.num_head = 16
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])
class TestMemEffAPIVariableDtypeBF16(TestMemEffAttentionVariableAPI):
def setUp(self):
self.name = "MemEffAPIVariable_bf16"
self.place = paddle.CUDAPlace(0)
self.batch_size = 2
self.num_head = 8
self.seq_len = 32
self.dim_head = 128
self.seq_lens = paddle.to_tensor(
[
self.seq_len // 2,
self.seq_len,
],
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'bfloat16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
self.scale = 1.0 / np.sqrt(self.shape[-1])
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.2",
)
class TestMemEffAPIVariableDtypeFP16Static(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariableStatic_fp16"
self.place = paddle.CUDAPlace(0)
self.batch_size = 3
self.num_head = 16
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
).numpy()
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
).numpy()
self.q = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape).astype(self.dtype)
self.v = np.random.random(self.shape).astype(self.dtype)
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.ref_out = naive_attention_impl(
paddle.to_tensor(self.q),
paddle.to_tensor(self.k),
paddle.to_tensor(self.v),
paddle.to_tensor(self.attention_mask),
self.scale,
)
def test_all(self):
paddle.enable_static()
with program_guard(Program(), Program()):
q = paddle.static.data(
name="query", shape=self.shape, dtype=self.dtype
)
k = paddle.static.data(
name="key", shape=self.shape, dtype=self.dtype
)
v = paddle.static.data(
name="value", shape=self.shape, dtype=self.dtype
)
mask = paddle.static.data(
name="mask",
shape=[self.batch_size, 1, self.seq_len, self.seq_len],
dtype=self.dtype,
)
seq_lens = paddle.static.data(
name="seq_lens", shape=[self.batch_size, 1], dtype="int32"
)
out = variable_length_memory_efficient_attention(
q, k, v, seq_lens, seq_lens, mask, self.scale
)
exe = fluid.Executor()
res = exe.run(
feed={
"query": self.q,
"key": self.k,
"value": self.v,
"seq_lens": self.seq_lens,
"mask": self.attention_mask,
},
fetch_list=[out],
)
paddle.disable_static()
np.testing.assert_allclose(res[0], self.ref_out, rtol=5e-03, atol=1e-03)
if __name__ == '__main__':
unittest.main()