From 4036c9373c264332b7f23cdc72f6cfbb9492f8ee Mon Sep 17 00:00:00 2001 From: lzy <569782149@qq.com> Date: Thu, 10 Aug 2023 17:23:40 +0800 Subject: [PATCH] Add variable_length_memory_efficient_attention (#55400) * add variable_length_memory_efficient_attention * update variable_length_memory_efficient_attention unittest * update variable_length_mem_eff_attn's docs and unittest * update variable_length_mem_eff_attn's docs * Update test_variable_length_memory_efficient_attention.py * Update variable_length_memory_efficient_attention.cu * fix codestyle * fix variable_length_fmha's docs and unittest * fix variable_length_fmha's docs --- paddle/phi/api/yaml/ops.yaml | 10 + paddle/phi/infermeta/multiary.cc | 75 ++ paddle/phi/infermeta/multiary.h | 11 + paddle/phi/kernels/CMakeLists.txt | 13 +- .../memory_efficient_attention/.gitignore | 2 +- .../default_fmha_grouped.h | 301 ++++++ .../gemm/attention_scaling_coefs_updater.h | 517 +++++++++ .../gemm/base_grouped.h | 487 +++++++++ .../gemm/fmha_grouped.h | 979 ++++++++++++++++++ .../gemm/fmha_grouped_problem_visitor.h | 184 ++++ .../gemm/gemm_grouped.h | 62 ++ .../generate_variable_forward_kernels.py | 573 ++++++++++ ...iable_length_memory_efficient_attention.cu | 132 +++ .../paddle/incubate/nn/functional/__init__.py | 4 + ...iable_length_memory_efficient_attention.py | 116 +++ test/legacy_test/CMakeLists.txt | 1 + ...iable_length_memory_efficient_attention.py | 296 ++++++ 17 files changed, 3760 insertions(+), 3 deletions(-) create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/attention_scaling_coefs_updater.h create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/base_grouped.h create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped_problem_visitor.h create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/gemm_grouped.h create mode 100644 paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py create mode 100644 paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu create mode 100644 python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py create mode 100644 test/legacy_test/test_variable_length_memory_efficient_attention.py diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index c5bca9f4920..ecc29de613d 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2682,6 +2682,16 @@ skip_transform : found_infinite inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps) +- op : variable_length_memory_efficient_attention + args : (Tensor query, Tensor key, Tensor value, Tensor seq_lens, Tensor kv_seq_lens, Tensor mask, float scale, bool causal) + output : Tensor + infer_meta : + func : VariableLengthMemoryEfficientAttentionInferMeta + kernel : + func : variable_length_memory_efficient_attention + data_type : query + optional : mask + - op : view_dtype args : (Tensor input, DataType dtype) output : Tensor(out) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 62628c781b3..ee84f6d169d 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2570,6 
+2570,81 @@ void MemoryEfficientAttentionInferMeta(const MetaTensor& query, seed_and_offset->set_dtype(phi::DataType::INT64); } +void VariableLengthMemoryEfficientAttentionInferMeta( + const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& seq_lens, + const MetaTensor& kv_seq_lens, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out) { + PADDLE_ENFORCE_EQ( + query.dims().size(), + 4, + phi::errors::InvalidArgument("Query should be a 4-D tensor" + "But received Query dimension(%s)", + query.dims().size())); + PADDLE_ENFORCE_EQ( + key.dims().size(), + 4, + phi::errors::InvalidArgument("Key should be a 4-D tensor" + "But received Key dimension(%s)", + key.dims().size())); + PADDLE_ENFORCE_EQ( + value.dims().size(), + 4, + phi::errors::InvalidArgument("Value should be a 4-D tensor" + "But received Value dimension(%s)", + value.dims().size())); + + const int64_t query_batch_size = query.dims()[0]; + const int64_t query_num_head = query.dims()[1]; + const int64_t query_seq_length = query.dims()[2]; + const int64_t query_head_size = query.dims()[3]; + + const int64_t key_batch_size = key.dims()[0]; + const int64_t key_num_head = key.dims()[1]; + const int64_t key_seq_length = key.dims()[2]; + const int64_t key_head_size = key.dims()[3]; + + const int64_t value_batch_size = value.dims()[0]; + const int64_t value_num_head = value.dims()[1]; + const int64_t value_seq_length = value.dims()[2]; + const int64_t value_head_size = value.dims()[3]; + + PADDLE_ENFORCE_EQ( + ((query_batch_size == key_batch_size) && + (key_batch_size == value_batch_size)), + true, + phi::errors::InvalidArgument( + "The batch size of Query, Key, Value should be equal.")); + + PADDLE_ENFORCE_EQ( + ((query_num_head == key_num_head) && (key_num_head == value_num_head)), + true, + phi::errors::InvalidArgument( + "The head number of Query, Key, Value should be equal.")); + + PADDLE_ENFORCE_EQ(query_head_size == key_head_size, + true, + phi::errors::InvalidArgument( + "The head size of Query, Key should be equal.")); + + PADDLE_ENFORCE_EQ(key_seq_length == value_seq_length, + true, + phi::errors::InvalidArgument( + "The seq length of Key, Value should be equal.")); + + std::vector out_dims( + {query_batch_size, query_num_head, query_seq_length, value_head_size}); + + out->set_dims(phi::make_ddim(out_dims)); + out->set_dtype(query.dtype()); + out->set_layout(query.layout()); +} + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs) { const size_t inputs_num = inputs.size(); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index c524e8b9600..2d24b2252a5 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -471,6 +471,17 @@ void MemoryEfficientAttentionInferMeta(const MetaTensor& query, MetaTensor* logsumexp, MetaTensor* seed_and_offset); +void VariableLengthMemoryEfficientAttentionInferMeta( + const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& seq_lens, + const MetaTensor& kv_seq_lens, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out); + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs); diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 84df39c7eb6..9a917004c83 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -69,6 +69,13 @@ if(WITH_CUTLASS) --cuda_arch "${NVCC_ARCH_BIN}" RESULT_VARIABLE memory_efficient_attention_gen_res) 
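+  # As with the generator invocation above, the block added below runs
+  # generate_variable_forward_kernels.py at configure time so the variable-length
+  # FMHA kernels are written to
+  # fusion/cutlass/memory_efficient_attention/autogen_variable/impl before they
+  # are globbed into kernel_cu further down.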
+ execute_process( + COMMAND + ${PYTHON_EXECUTABLE} + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py + --cuda_arch "${NVCC_ARCH_BIN}" + RESULT_VARIABLE memory_efficient_attention_gen_res) + if(NOT memory_efficient_attention_gen_res EQUAL 0) message( FATAL_ERROR @@ -79,9 +86,11 @@ if(WITH_CUTLASS) file( GLOB cutlass_cu RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "fusion/cutlass/conv2d/generated/*.cu" "fusion/cutlass/conv2d/*.cu" + "fusion/cutlass/conv2d/generated/*.cu" + "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu" - "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu") + "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu" + "fusion/cutlass/memory_efficient_attention/autogen_variable/impl/*.cu") list(APPEND kernel_cu ${cutlass_cu}) endif() diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/.gitignore b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/.gitignore index 5b3f298a226..b7d72287e7c 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/.gitignore +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/.gitignore @@ -1 +1 @@ -autogen +autogen* diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h new file mode 100644 index 00000000000..5d8bfa32e64 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h @@ -0,0 +1,301 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "./gemm_kernel_utils.h" +#include "gemm/attention_scaling_coefs_updater.h" +#include "gemm/find_default_mma.h" +#include "gemm/fmha_grouped.h" +#include "gemm/mma_from_smem.h" +#include "transform/tile_smem_loader.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + bool maskIsAligned_, + int QueriesPerBlock_, + int KeysPerBlock_, + bool SingleValueIteration_, + GroupScheduleMode GroupScheduleMode_, + bool AddMask, + bool MaskBroadcastRow> +struct DefaultFMHAGrouped { + using scalar_t = scalar_t_; + using accum_t = float; + using output_t = scalar_t; + + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + + using ArchTag = ArchTag_; + static bool const kIsAligned = isAligned_; + static bool const kAddMask = AddMask; + static bool const kMaskBroadcastRow = MaskBroadcastRow; + static bool const kSingleValueIteration = SingleValueIteration_; + static int const kKeysPerBlock = KeysPerBlock_; + static bool const kMaskIsAligned = maskIsAligned_; + static int const kWarpSize = 32; + static int const kNumWarpsPerBlock = + QueriesPerBlock_ * KeysPerBlock_ / (kWarpSize * kWarpSize); + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + + using GemmType = gemm_kernel_utils::DefaultGemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = scalar_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator>; + + static int const kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator>::DefaultMma; + + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + typename Mma::Operator::IteratorC, + ElementAccumulator, + kWarpSize>::Updater; + static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); + + // used for efficient load of mask_ tile Bij from global to shared memory + using MaskLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + kMaskIsAligned ? 128 / cutlass::sizeof_bits::value : 1>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /* + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + + using GemmType = typename MM0::GemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator>; + + static int const kAlignmentA = DefaultConfig::kAlignmentA; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = typename MM0::ThreadblockShape; + using WarpShape = typename MM0::WarpShape; + using InstructionShape = typename MM0::InstructionShape; + + using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using ThreadblockSwizzle = void; // Swizzling is unused + static bool const kSplitKSerial = false; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage, + false>; + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + /// Define the kernel in terms of the default kernel + using FMHAKernel = kernel::FMHAGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/attention_scaling_coefs_updater.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/attention_scaling_coefs_updater.h new file mode 100644 index 00000000000..e1d11720595 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/attention_scaling_coefs_updater.h @@ -0,0 +1,517 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../gemm_kernel_utils.h" +#include "../kernel_forward.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* Iterates on the accumulator and corresponding position on result matrix + +(1) Update `mi[r]` to the max value of the row `r` +(2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + +All of this is done on registers, before we store all of this +on shared memory for the next matmul with Value. + +We have multiple implementations, because each configuration has a different way +of iterating in the accumulators. +*/ + +template +struct RegisterOps { + template + CUTLASS_DEVICE static void update( + typename T::Fragment& frag_o, // NOLINT + typename T::Fragment& frag, // NOLINT + cutlass::Array& mi, // NOLINT + cutlass::Array& m_prime, // NOLINT + cutlass::Array& s_prime, // NOLINT + cutlass::Array& + addition_storage, // NOLINT + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int16_t max_col, + typename T::TensorCoord const& tile_offset, + float scaling) { + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + if (!kIsFirst) { + if (thread_id < kQueriesPerBlock) { + m_prime[thread_id] = mi[thread_id]; + } + __syncthreads(); + } + + auto lane_offset = BASE::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + BASE::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (kFullColumns || accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max * scaling); + }); + } + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + // Make sure we all share the update values for `mi` + __syncthreads(); + + if (thread_id < kQueriesPerBlock) { + auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id])); + m_prime[thread_id] = m_prime_exp; + s_prime[thread_id] *= m_prime_exp; + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !kIsFirst) { + accum_t mp; + BASE::iterateRows( + lane_offset, + [&](int accum_m) { mp = m_prime[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; }, + [&](int accum_m) {}); + __syncthreads(); + } + // Update accum_m, accum_n, ... 
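+    // The block below finishes the online-softmax update for this tile:
+    //   (a) each accumulator element becomes exp2f(frag[idx] - kLog2e * mi[row])
+    //       (frag was already multiplied by scaling * kLog2e above), i.e.
+    //       exp(scaling * qk - mi[row]) in natural-log terms;
+    //   (b) per-row partial sums are reduced across the lanes that share a row
+    //       (BASE::reduceSameRow), staged in addition_storage, and then folded
+    //       into s_prime[row].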
+ { + accum_t mi_row, total_row; + BASE::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = (kFullColumns || accum_n < max_col) + ? exp2f(frag[idx] - mi_row) + : accum_t(0.0); + }, + [&](int accum_m) {}); + BASE::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (BASE::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + addition_storage[accum_m + kQueriesPerBlock * + tile_offset.column()] = total_row; + // atomicAdd(&s_prime[accum_m], total_row); + } + }); + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + total_row = s_prime[id]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kWarpN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + } +}; + +template +struct AttentionScalingCoefsUpdaterSm80 + : RegisterOps, + T, + accum_t, + kWarpSize> { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, // NOLINT + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, + DT& myValue, // NOLINT + F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AttentionScalingCoefsUpdaterVolta + : RegisterOps, + T, + accum_t, + kWarpSize> { + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + using Policy 
= typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, + DT& myValue, // NOLINT + F fn) { + static_assert(cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, // NOLINT + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AttentionScalingCoefsUpdaterSimt + : RegisterOps, + T, + 
accum_t, + kWarpSize> { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert(cutlass::platform::is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, + DT& myValue, // NOLINT + F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, // NOLINT + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = + n + Policy::LaneMmaShape::kN * + (mma_n + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE + get_lane_offset(int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert(cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = + lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultAttentionScalingCoefsUpdater; + +// Simt +template +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaSimtTileIterator, + accum_t, + kWarpSize> { + using Iterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Updater = + AttentionScalingCoefsUpdaterSimt; +}; + +// TensorOp - Volta +template +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using Iterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Updater = + AttentionScalingCoefsUpdaterVolta; +}; + +// TensorOp - Sm75+ +template +struct DefaultAttentionScalingCoefsUpdater< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using Iterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Updater = + AttentionScalingCoefsUpdaterSm80; +}; diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/base_grouped.h 
b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/base_grouped.h new file mode 100644 index 00000000000..94f251ea8a5 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/base_grouped.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Base device-level grouped kernel. 
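+
+  BaseGrouped wraps a kernel-level grouped GEMM/FMHA (BaseKernel_). On the host it
+  precomputes the problem-visitor schedule when the kernel requires it, copies that
+  schedule into the caller-provided workspace, raises the dynamic shared-memory
+  limit when SharedStorage exceeds 48 KB, and launches cutlass::Kernel<BaseKernel>
+  on a grid of threadblock_count blocks.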
+*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class BaseGrouped { + public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = + typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + + protected: + /// Kernel parameters object + typename BaseKernel::Params params_; + + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count( + const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = + cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " + << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return 
Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, + int32_t tile_count, + void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute( + args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + reinterpret_cast(host_workspace.data())); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, const std::vector& indices) { + // For now, simply create a copy of the data and then copy over to the + // original. + std::vector copy(indices.size()); + for (int i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseGrouped() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + return BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } else { + return 0; + } + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + int smem_size = + static_cast(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + Kernel, + BaseKernel::kThreadCount, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K 
dimension. + static void sort_problems(int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_ptr, + int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, + int64_t* ldc_host_ptr, + int64_t* ldd_host_ptr, + int64_t* offset_A_ptr, + int64_t* offset_B_ptr, + int64_t* offset_C_ptr, + int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), + indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > + problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, + int problem_count = 0, + int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + // printf("custom base\n"); + static cudaDeviceProp properties; + static bool count = true; + if (count) { + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; + } + + result = cudaGetDeviceProperties(&properties, device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error " + << cudaGetErrorString(result)); + return 0; + } + } + count = false; + + bool override_sm_count = + (available_sm_count < 0 || + available_sm_count > properties.multiProcessorCount); + if (override_sm_count) { + available_sm_count = properties.multiProcessorCount; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return + // total_tiles unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating + // through problem sizes to determine that they have no work to do. This + // competes for cycles with those threadblocks that are assigned tiles to + // compute. + return min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. 
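+  /// When the problem visitor requires precomputation, the per-tile schedule is
+  /// built on the host and copied into `workspace` (which must be at least
+  /// get_workspace_size(args) bytes); the kernel's dynamic shared-memory limit is
+  /// raised with cudaFuncSetAttribute if SharedStorage exceeds 48 KB.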
+ Status initialize(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace " + << workspace + << ", stream: " << (stream ? "non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = + static_cast(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + params_.update(args, workspace, tile_count); + } else { + params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + // + // Configure grid and block dimensions + // + + if (!params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + dim3 grid(params_.threadblock_count, 1, 1); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = + static_cast(sizeof(typename BaseKernel::SharedStorage)); + + // + // Launch kernel + // + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Initializes and runs the kernel. 
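+  /// Equivalent to calling initialize(args, workspace, stream) followed by
+  /// run(stream); any non-success status from initialization is returned as-is.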
+ Status operator()(Arguments const& args, + void* workspace, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h new file mode 100644 index 00000000000..1a0d0be7090 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped.h @@ -0,0 +1,979 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Grouped FMHA kernel +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "../epilogue/epilogue_rescale_output.h" +#include "../gemm_kernel_utils.h" +#include "./attention_scaling_coefs_updater.h" +#include "./fmha_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FMHAGrouped { + public: + using MM0 = MM0_; + using MM1 = MM1_; + + using scalar_t = scalar_t_; + using accum_t = accum_t_; + using output_t = output_t_; + using output_accum_t = output_accum_t_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + static constexpr bool kNeedsOutputAccumulatorBuffer = + !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + // Parameters to satisfy BaseGrouped + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = accum_t; + using LayoutA = typename MM0::LayoutA; + using LayoutB = typename MM0::ElementB; + using LayoutC = typename MM1::ElementC; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static int const kAlignmentA = MM0::kAlignmentA; + static int const kAlignmentB = MM0::kAlignmentB; + static int const kAlignmentC = 1; + using Mma = typename MM1::Mma; + using EpilogueOutputOp = typename MM1::EpilogueOutputOp; + using ThreadblockSwizzle = void; + using Operator = typename MM1::Operator; + using WarpShape = typename MM1::WarpShape; + using InstructionShape = typename MM1::InstructionShape; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementM = scalar_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutQ = typename MM0::LayoutA; + using LayoutK = typename MM0::LayoutB; + using LayoutP = typename MM0::LayoutC; + using LayoutM = typename MM0::LayoutC; + using LayoutV = typename MM1::LayoutB; + using LayoutO = typename MM1::LayoutC; + + static bool const kPreloadV = + (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16); + + static int const kAlignmentQ = MM0::kAlignmentA; + static int const kAlignmentK = MM0::kAlignmentB; + static int const kAlignmentV = 1; + static int64_t const kAlignmentM = kMaskIsAligned ? 
kAlignmentQ : 1; + + using ThreadblockShape = typename MM0::ThreadblockShape; + + static int const kQueriesPerBlock = ThreadblockShape::kM; + static int const kKeysPerBlock = ThreadblockShape::kN; + + /// Warp count (concept: GemmShape) + using WarpCount = typename MM1::WarpCount; + static int const kThreadsPerWarp = 32; + static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + + using ProblemVisitor = FMHAGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord *problem_sizes0; + GemmCoord *problem_sizes1; + + int problem_count; + int threadblock_count; + int num_heads; + + ElementQ *ptr_Q; + ElementK *ptr_K; + ElementP *ptr_P; + ElementM *ptr_M; + ElementV *ptr_V; + ElementO *ptr_O; + ElementOAccum *ptr_O_accum; + + typename LayoutQ::Stride::LongIndex ldq; + typename LayoutK::Stride::LongIndex ldk; + typename LayoutK::Stride::LongIndex ldm; + typename LayoutP::Stride::LongIndex ldv; + typename LayoutO::Stride::LongIndex ldo; + + int64_t kElementQ; + int64_t kElementK; + int64_t kElementM; + int64_t kElementV; + int64_t kElementO; + + // Scale + ElementAccumulator scale; + + // Whether causal masking is to be performed + bool causal; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + num_heads(0), + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_M(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(0), + ldk(0), + ldm(0), + ldv(0), + ldo(0), + scale(0), + kElementQ(0), + kElementK(0), + kElementM(0), + kElementV(0), + kElementO(0), + causal(false), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord *problem_sizes0, + GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + int num_heads, + ElementQ *ptr_Q, + ElementK *ptr_K, + ElementM *ptr_M, + ElementV *ptr_V, + ElementO *ptr_O, + ElementOAccum *ptr_O_accum, + typename LayoutQ::Stride::LongIndex ldq, + typename LayoutK::Stride::LongIndex ldk, + typename LayoutM::Stride::LongIndex ldm, + typename LayoutV::Stride::LongIndex ldv, + typename LayoutO::Stride::LongIndex ldo, + int64_t kElementQ, + int64_t kElementK, + int64_t kElementM, + int64_t kElementV, + int64_t kElementO, + bool causal, + ElementAccumulator scale, + GemmCoord *host_problem_sizes = nullptr) + : problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + threadblock_count(threadblock_count), + num_heads(num_heads), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_M(ptr_M), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer + ? 
ptr_O_accum + : reinterpret_cast(ptr_O)), // wip + ldq(ldq), + ldk(ldk), + ldm(ldm), + ldv(ldv), + ldo(ldo), + kElementQ(kElementQ), + kElementK(kElementK), + kElementM(kElementM), + kElementV(kElementV), + kElementO(kElementO), + causal(causal), + scale(scale), + host_problem_sizes(host_problem_sizes) {} + + bool __host__ check_supported() { + CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); + CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); + CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); + if (ptr_M != nullptr) { + CHECK_ALIGNED_PTR(ptr_M, kAlignmentM); + XFORMERS_CHECK(ldm % kAlignmentM == 0, + "attn_mask is not correctly aligned"); + } + XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int num_heads; + + ElementQ *ptr_Q; + ElementK *ptr_K; + ElementP *ptr_P; + ElementM *ptr_M; + ElementV *ptr_V; + ElementO *ptr_O; + ElementOAccum *ptr_O_accum; + + typename LayoutQ::Stride::LongIndex ldq; + typename LayoutK::Stride::LongIndex ldk; + typename LayoutM::Stride::LongIndex ldm; + typename LayoutP::Stride::LongIndex ldv; + typename LayoutO::Stride::LongIndex ldo; + + int64_t kElementQ; + int64_t kElementK; + int64_t kElementM; + int64_t kElementV; + int64_t kElementO; + + ElementAccumulator scale; + bool causal; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_M(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(0), + ldk(0), + ldm(0), + ldv(0), + ldo(0), + kElementQ(0), + kElementK(0), + kElementM(0), + kElementV(0), + kElementO(0), + causal(false), + scale(0) {} + + explicit CUTLASS_HOST_DEVICE Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) + : problem_visitor(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, + tile_count), + threadblock_count(args.threadblock_count), + num_heads(args.num_heads), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_M(args.ptr_M), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer + ? args.ptr_O_accum + : reinterpret_cast(args.ptr_O)), + ldq(args.ldq), + ldk(args.ldk), + ldm(args.ldm), + ldv(args.ldv), + ldo(args.ldo), + kElementQ(args.kElementQ), + kElementK(args.kElementK), + kElementM(args.kElementM), + kElementV(args.kElementV), + kElementO(args.kElementO), + causal(args.causal), + scale(args.scale) {} + + // CUTLASS_HOST_DEVICE + void update(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, + tile_count); + threadblock_count = args.threadblock_count; + num_heads = args.num_heads; + ptr_Q = args.ptr_Q; + ptr_K = args.ptr_K; + ptr_P = args.ptr_P; + ptr_M = args.ptr_M; + ptr_V = args.ptr_V; + ptr_O = args.ptr_O; + ptr_O_accum = kNeedsOutputAccumulatorBuffer + ? 
args.ptr_O_accum + : reinterpret_cast(args.ptr_O); + ldq = args.ldq; + ldk = args.ldk; + ldm = args.ldm; + ldv = args.ldv; + ldo = args.ldo; + kElementQ = args.kElementQ; + kElementK = args.kElementK; + kElementM = args.kElementM; + kElementV = args.kElementV; + kElementO = args.kElementO; + causal = args.causal; + scale = args.scale; + } + }; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::MaskLoader::SmemTile mask; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage & + epilogue_shared_storage() { + return epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::MaskLoader::SmemTile mask; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage & + epilogue_shared_storage() { + return after_mm0.epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + using SharedStorage = typename cutlass::platform::conditional< + kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + private: + // Parameters to be used by an individual tile + struct TileParams { + CUTLASS_HOST_DEVICE + static int query_start(int threadblock_idx) { + return threadblock_idx * kQueriesPerBlock; + } + + // Returns whether this threadblock computes within the number of queries, + // which is determined by the M dimension of problem 0 + CUTLASS_HOST_DEVICE + static bool can_compute(int threadblock_idx, + const GemmCoord &problem_size0) { + return query_start(threadblock_idx) < problem_size0.m(); + } + + CUTLASS_HOST_DEVICE + static int num_queries(int threadblock_idx, + const GemmCoord &problem_size0) { + return problem_size0.m() - query_start(threadblock_idx); + } + + CUTLASS_HOST_DEVICE + static int num_keys(int threadblock_idx, + const GemmCoord &problem_size0, + bool causal) { + int nk = problem_size0.n(); + if (causal) { + nk = cutlass::fast_min( + int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); + } + return nk; + } + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + FMHAGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static CUTLASS_DEVICE int16_t thread_id() { return threadIdx.x; } + + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.x / kThreadsPerWarp; + } + + static CUTLASS_DEVICE int8_t 
lane_id() { + return threadIdx.x % kThreadsPerWarp; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, + SharedStorage &shared_storage) { // NOLINT + auto &m_prime = shared_storage.m_prime; + auto &s_prime = shared_storage.s_prime; + [[maybe_unused]] auto &si = shared_storage.after_mm0.si; + auto &mi = shared_storage.mi; + + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + const int32_t threadblock_idx = + int32_t(problem_visitor.threadblock_idx()); + + if (!TileParams::can_compute(threadblock_idx, problem_size0)) { + problem_visitor.advance(gridDim.x); + continue; + } + + const int32_t problem_idx = problem_visitor.problem_index(); + const int32_t batch_idx = problem_idx / params.num_heads; + + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = ElementAccumulator(0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + } + + ElementO *ptr_O = params.ptr_O + problem_idx * params.kElementQ + + TileParams::query_start(threadblock_idx) * params.ldo; + ElementOAccum *ptr_O_accum = + params.ptr_O_accum + problem_idx * params.kElementO + + TileParams::query_start(threadblock_idx) * params.ldo; + const int num_queries = + TileParams::num_queries(threadblock_idx, problem_size0); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)params.ldo}, + ptr_O, + typename OutputTileIterator::TensorCoord{num_queries, + problem_size1.n()}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = + [&](int col) -> typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)params.ldo}, + ptr_O_accum, + typename OutputTileIteratorAccum::TensorCoord{num_queries, + problem_size1.n()}, + thread_id(), + {0, col}); + }; + + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + const int num_keys = + TileParams::num_keys(threadblock_idx, problem_size0, params.causal); + + for (int32_t iter_key_start = 0; iter_key_start < num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); + int32_t problem_size_0_n = cutlass::fast_min((int32_t)kKeysPerBlock, + num_keys - iter_key_start); + int32_t const &problem_size_0_k = problem_size0.k(); + int32_t const &problem_size_1_n = problem_size1.n(); + int32_t const &problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)}, + params.ptr_V + problem_idx * params.kElementV + + iter_key_start * params.ldv, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + MM1::Mma::prologue(shared_storage.after_mm0.mm1.mm, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and + // `m_prime` updated from end of prev iter + + // + // MATMUL: Q.K_t + // + // 
Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + ElementQ *ptr_Q = params.ptr_Q + problem_idx * params.kElementQ + + TileParams::query_start(threadblock_idx) * params.ldq; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(params.ldq)), + ptr_Q, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + {0, 0}); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(params.ldk)), + params.ptr_K + problem_idx * params.kElementK + + iter_key_start * params.ldk, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + {0, 0}); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), warp_id(), lane_id()); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = {(warp_id() % MM0::Mma::WarpCount::kM), + (warp_id() / MM0::Mma::WarpCount::kM)}; + + // apply attention mask if applicable + if (kAddMask) { + accum = cutlass::multiplies()( + params.scale, accum); + // load mask tile Bij into shared memory + typename MM0::MaskLoader::GmemTileIterator mask_iter( + {cutlass::layout::RowMajor(params.ldm)}, + // attn_mask_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + params.ptr_M + batch_idx * params.kElementM + + TileParams::query_start(threadblock_idx) * params.ldm + + iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef + mask_tensor_ref( + shared_storage.after_mm0.mask.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::MaskLoader::SmemTileIterator smem_tile_iter( + mask_tensor_ref, thread_id()); + MM0::MaskLoader::load(mask_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += mask_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = TileParams::query_start(threadblock_idx) + accum_m - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + num_keys - iter_key_start >= kKeysPerBlock, + kFullColumns, + ([&] { + // Update `mi` from accum stored in registers + // Also updates 
`accum` with accum[i] <- + // exp(accum[i] * scale + // - mi) + MM0::ScalingCoefsUpdater::update< + kQueriesPerBlock, + MM0::MmaCore::WarpCount::kCount, + MM0::MmaCore::WarpCount::kN, + kFullColumns, + kIsFirst, + kKeepOutputInRF>( + accum_o, + accum, + mi, + m_prime, + s_prime, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iteratorC_tile_offset, + kAddMask ? 1.0f : params.scale); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = warp_id() % (MM0::Mma::Base::WarpCount::kM * + MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = + cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = + kKeepOutputInRF ? 1 + : ceil_div((int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + + // Iterate over the N dimension of GEMM1 + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = (problem_size_1_k + MM1::Mma::Shape::kK - 1) / + MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in + // accum (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)}, + params.ptr_V + problem_idx * params.kElementV + + iter_key_start * params.ldv, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + static_cast(thread_id()), + static_cast(warp_id()), + static_cast(lane_id()), + static_cast(problem_size_1_k)); + + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = + typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + output_accum_t, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue:: + threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>:: + type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename 
DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1:: + OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = gemm_kernel_utils::call_conditional< + kIsLast, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kKeepOutputInRF) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + const bool kIsFirst = true; + const bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue::thread:: + MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue(shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped_problem_visitor.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped_problem_visitor.h new file mode 100644 index 00000000000..7fd8047c9f5 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/fmha_grouped_problem_visitor.h @@ -0,0 +1,184 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped FMHA +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct FMHAGroupedProblemSizeHelper { + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape( + const cutlass::gemm::GemmCoord &problem) { + // FMHA only partitions tiles across the M dimension. 
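+    // The expression below is a ceiling division over M: for example, with
+    // problem.m() == 100 and ThreadblockShape::kM == 32 (an illustrative
+    // value), it yields (100 - 1 + 32) / 32 == 4 threadblock tiles along M,
+    // while the N and K grid dimensions stay at 1.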
+ return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + 1, + 1); + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem( + cutlass::gemm::GemmCoord &problem) { // NOLINT + } // NOLINT + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord &grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct FMHAGroupedProblemVisitor + : public GroupedProblemVisitor< + detail::FMHAGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + using ProblemSizeHelper = + detail::FMHAGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : problem_sizes0(nullptr), + problem_sizes1(nullptr), + problem_count(0), + workspace(nullptr), + tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0) + : problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + + /// Convert the FMHA-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams( // Set problem_sizes as problem_sizes1 because these + // determine shape of the final output of FMHA + problem_sizes1, + problem_count, + workspace, + tile_count); + } + }; + + // + // Methods + // + CUTLASS_DEVICE + FMHAGroupedProblemVisitor(Params const ¶ms_, + SharedStorage &shared_storage_, // NOLINT + int32_t block_idx) + : Base(params_.to_base(), shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/gemm_grouped.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/gemm_grouped.h new file mode 100644 index 00000000000..18398c1db2f --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/gemm_grouped.h @@ -0,0 +1,62 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Device-level grouped GEMM. +*/ + +#pragma once + +#include "./base_grouped.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class GemmGrouped : public BaseGrouped { + public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py new file mode 100644 index 00000000000..7d7c12ce61a --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py @@ -0,0 +1,573 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# Generates combination of kernels - implementations and registry + +# Kernels are ordered (see `sort_index`), and when dispatching, +# we select the first kernel in the list that supports the inputs + +import argparse +import collections +import itertools +import os +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Tuple, TypeVar + +DEFAULT_ARCH = [50, 70, 75, 80] +MAX_ARCH = 90 +ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION" + +assert sorted(DEFAULT_ARCH) == DEFAULT_ARCH + + +def find_arch_range(min_arch, max_arch): + assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH + assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH + assert min_arch <= max_arch + n = len(DEFAULT_ARCH) + + start_idx = n - 1 + for i in range(n - 1): + if DEFAULT_ARCH[i] <= min_arch and min_arch < DEFAULT_ARCH[i + 1]: + start_idx = i + break + + end_idx = n + for i in range(n - 1): + if DEFAULT_ARCH[i] <= max_arch and max_arch < DEFAULT_ARCH[i + 1]: + end_idx = i + 1 + + return DEFAULT_ARCH[start_idx:end_idx] + + +def find_max_arch(arch): + arch = sorted(arch) + idx = DEFAULT_ARCH.index(arch[-1]) + if idx == len(DEFAULT_ARCH) - 1: + return MAX_ARCH + else: + return DEFAULT_ARCH[idx + 1] + + +def convert_to_arch_list(arch): + arch = arch.lower().strip() + if arch == "all": + return DEFAULT_ARCH + + arch = [int(s.strip()) for s in arch.split(';') if s.strip()] + arch = list(set(arch)) + arch.sort() + return find_arch_range(arch[0], arch[-1]) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="The argument for generating the memory efficient kernels." 
+ ) + parser.add_argument( + "--dst_path", + type=str, + default=str(Path(__file__).parent), + help="The destination path to save the generated files.", + ) + parser.add_argument( + "--cuda_arch", + type=convert_to_arch_list, + default=convert_to_arch_list("All"), + help="The CUDA architecture to be generated.", + ) + args = parser.parse_args() + args.max_arch = find_max_arch(args.cuda_arch) + return args + + +args = parse_args() + +DTYPES = { + "f32": "float", + "f16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +SM = args.cuda_arch + +KERNEL_IMPL_TEMPLATE = """ + +void {NAME}({CPP_CLASS} default_fmha, Params ¶ms, const phi::GPUContext& ctx) {{ + using AttentionKernel = typename decltype(default_fmha)::FMHAKernel; + using FMHA = cutlass::gemm::device::GemmGrouped; + using scalar_t = typename FMHA::GemmKernel::scalar_t; + using accum_t = typename FMHA::GemmKernel::accum_t; + using output_t = typename FMHA::GemmKernel::output_t; + using output_accum_t = typename FMHA::GemmKernel::output_accum_t; + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementM = scalar_t; + using ElementAccumulator = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + + int problem_count = params.num_batches * params.num_heads; + + std::vector problem_sizes1; + problem_sizes1.reserve(problem_count); + + phi::Allocator::AllocationPtr problem_sizes_device0{{nullptr}}; + phi::Allocator::AllocationPtr problem_sizes_device1{{nullptr}}; + problem_sizes_device0 = phi::memory_utils::Alloc( + ctx.GetPlace(), + problem_count * sizeof(GemmCoord), + phi::Stream(reinterpret_cast(ctx.stream()))); + problem_sizes_device1 = phi::memory_utils::Alloc( + ctx.GetPlace(), + problem_count * sizeof(GemmCoord), + phi::Stream(reinterpret_cast(ctx.stream()))); + GemmCoord* problem0_device = + reinterpret_cast(problem_sizes_device0->ptr()); + GemmCoord* problem1_device = + reinterpret_cast(problem_sizes_device1->ptr()); + get_problem_sizes<<>>( + params.seq_lens, + params.kv_seq_lens, + problem0_device, + problem1_device, + params.num_batches, + params.num_heads, + params.head_size, + params.value_head_size); + phi::memory_utils::Copy(phi::CPUPlace(), + problem_sizes1.data(), + ctx.GetPlace(), + problem1_device, + sizeof(GemmCoord) * problem_count, + ctx.stream()); + if (AttentionKernel::kNeedsOutputAccumulatorBuffer) {{ + const int64_t output_size = params.num_batches * params.num_heads * + params.query_seq_len * params.value_head_size; + phi::Allocator::AllocationPtr tmp_output_accum_buffer_ptr{{nullptr}}; + tmp_output_accum_buffer_ptr = phi::memory_utils::Alloc( + ctx.GetPlace(), + output_size * sizeof(ElementOAccum), + phi::Stream(reinterpret_cast(ctx.stream()))); + params.output_accum_ptr = tmp_output_accum_buffer_ptr->ptr(); + }} + int threadblock_count = + FMHA::sufficient(problem_sizes1.data(), problem_count); + typename FMHA::Arguments args( + problem0_device, + problem1_device, + problem_count, + threadblock_count, + params.num_heads, + const_cast(reinterpret_cast(params.query_ptr)), + const_cast(reinterpret_cast(params.key_ptr)), + params.mask_ptr + ? const_cast(reinterpret_cast(params.mask_ptr)) + : nullptr, + const_cast(reinterpret_cast(params.value_ptr)), + reinterpret_cast(params.output_ptr), + AttentionKernel::kNeedsOutputAccumulatorBuffer + ? 
reinterpret_cast(params.output_accum_ptr) + : nullptr, + params.ldq, + params.ldk, + params.ldm, + params.ldv, + params.ldo, + params.ElementQ, + params.ElementK, + params.ElementM, + params.ElementV, + params.ElementO, + params.causal, + params.scale, + problem_sizes1.data()); + + FMHA fmha; + cutlass::Status status; + size_t workspace_size = fmha.get_workspace_size(args); + phi::DenseTensor workspace; + workspace.Resize(phi::make_ddim({{static_cast(workspace_size)}})); + ctx.template Alloc(&workspace); + status = fmha.initialize(args, workspace.data()); + if (status != cutlass::Status::kSuccess) {{ + PADDLE_THROW(phi::errors::Unimplemented( + "Failed to initialize CUTLASS Grouped FMHA kernel.")); + }} + status = fmha.run(ctx.stream()); + if (status != cutlass::Status::kSuccess) {{ + PADDLE_THROW(phi::errors::Unimplemented( + "Failed to run CUTLASS Grouped FMHA kernel.")); + }} +}} +""" + + +@dataclass(order=True) +class FwdKernel: + sort_index: Tuple[int, ...] = field(init=False, repr=False) + aligned: bool + mask_aligned: bool + dtype: str + sm_range: Tuple[int, int] + q: int + k: int + single_value_iter: bool + support_mask: bool = True + mask_broadcast: bool = False + dispatch_cond: Optional[str] = None + + def __post_init__(self) -> None: + # Set kernel selection priority + # The lowest value that matches inputs + # will be selected + self.sort_index = ( + # First select aligned kernel + 0 if self.aligned else 1, + 0 if self.support_mask else 1, + # Then keep output in RF + 0 if self.single_value_iter else 1, + self.q, + 0 if self.mask_aligned else 1, + 0 if self.mask_broadcast else 1, + ) + + @property + def _aligned_suffix(self) -> str: + return "aligned" if self.aligned else "notaligned" + + @property + def _mask_aligned_suffix(self) -> str: + return "ma" if self.mask_aligned else "mua" + + @property + def _mask_support_suffix(self) -> str: + return "sm" if self.support_mask else "usm" + + @property + def _mask_broadcast_suffix(self) -> str: + return "mb" if self.mask_broadcast else "mnb" + + @property + def _single_value_suffix(self) -> str: + return "rf" if self.single_value_iter else "urf" + + @property + def name(self) -> str: + return f"fmha_cutlassF_variable_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{self._single_value_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_sm{self.sm_range[0]}" + + @property + def cpp_class(self) -> str: + template_args = ", ".join( + [ + DTYPES[self.dtype], + f"cutlass::arch::Sm{self.sm_range[0]}", + "true" if self.aligned else "false", + "true" if self.mask_aligned else "false", + str(self.q), + str(self.k), + "true" if self.single_value_iter else "false", + "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", + "true" if self.support_mask else "false", + "false", + ] + ) + return f"cutlass::gemm::kernel::DefaultFMHAGrouped<{template_args}>" + + @property + def impl_group(self) -> str: + # Maps to file which will contain the implementation + return f"{self.dtype}_{self._aligned_suffix}_{self._mask_support_suffix}_{self._mask_aligned_suffix}_{self._mask_broadcast_suffix}_{self._single_value_suffix}_{self.q}x{self.k}" + + @property + def cpp_impl(self) -> str: + return KERNEL_IMPL_TEMPLATE.format( + CPP_CLASS=self.cpp_class, NAME=self.name + ) + + @classmethod + def get_all(cls) -> List["FwdKernel"]: + kernels: List[FwdKernel] = [] + for aligned, dtype, (sm, sm_max) in itertools.product( + [True, False], DTYPES.keys(), zip(SM, SM[1:] + [args.max_arch]) + ): + # Remove some kernels we don't use + if dtype == "bf16" 
and sm < 80: + continue + if not aligned and sm >= 80: + continue + for q, k, single_value_iter in [ + (32, 128, True), + (32, 128, False), + (64, 64, True), + ]: + for support_mask, mask_aligned in [ + (False, False), + (True, False), + (True, True), + ]: + kernels.append( + cls( + aligned=aligned, + dtype=dtype, + sm_range=(sm, sm_max), + q=q, + k=k, + single_value_iter=single_value_iter, + support_mask=support_mask, + mask_aligned=mask_aligned, + mask_broadcast=False, + ) + ) + return kernels + + +T = TypeVar("T", bound=FwdKernel) + + +def write_decl_impl( + kernels: List[T], family_name: str, impl_file: str, enable_def: str +) -> None: + cpp_file_header = """// This file is auto-generated. See "generate_variable_forward_kernels.py" +""" + + kernels.sort() + + implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list) + cat_to_kernels: Dict[ + Tuple[str, int, int], List[T] + ] = collections.defaultdict(list) + + dispatch_all = "" + declarations = cpp_file_header + "#pragma once\n" + declarations += f"#ifdef {enable_def}\n" + declarations += f"""#include "{impl_file}"\n""" + declarations += "namespace phi {\n" + + # Declaration of kernel functions + for k in kernels: + implfile_to_kernels[k.impl_group].append(k) + cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k) + + for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items(): + declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n" + declarations += "\n".join( + k.cpp_impl.split("{")[0].rstrip() + ";" for k in kernels + ) + dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}" + declarations += ( + f"\n\ntemplate void {dispatch_category_fn}(T cb) {{\n" + ) + for k in kernels: + _call = f"cb({k.cpp_class}(), {k.name});\n" + if k.dispatch_cond is not None: + _call = f"if ({k.dispatch_cond}) {_call}" + declarations += f" {_call}" + declarations += "}\n\n" + dispatch_all += f""" + if (std::is_same::value && {cat_sm} <= cc && cc < {cat_sm_max}) {{ + {dispatch_category_fn}(cb); + }}""" + + declarations += f""" +template +void dispatch_{family_name}(const ::phi::GPUContext &ctx, T cb) {{ + auto cc = ctx.GetComputeCapability(); + PADDLE_ENFORCE_GE( + cc, + 70, + phi::errors::InvalidArgument("the Nvidia GPU's Compute Capability must be greater or equal than 70")); + + using DT = typename ::phi::CutlassTrait::Type; +{dispatch_all} +}} +""" + declarations += "} // namespace phi\n" + declarations += f"#endif // {enable_def}\n" + + autogen_dir = Path(args.dst_path) / "autogen_variable" + os.makedirs(autogen_dir, exist_ok=True) + declaration_path = autogen_dir / f"{family_name}.h" + declaration_path.write_text(declarations) + + for f, f_kernels in implfile_to_kernels.items(): + impl_cu = cpp_file_header + impl_cu += f"#ifdef {enable_def}\n" + impl_cu += f"""#include "{impl_file}"\n""" + impl_cu += "namespace phi {\n" + for k in f_kernels: + impl_cu += k.cpp_impl + impl_cu += "} // namespace phi\n" + impl_cu += f"#endif // {enable_def}\n" + impl_path = autogen_dir / "impl" + os.makedirs(impl_path, exist_ok=True) + (impl_path / f"{family_name}_{f}.cu").write_text(impl_cu) + + +def write_main_header(): + main_header_content = ''' +#pragma once + +#ifdef {} + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +#include 
"cutlass/util/device_memory.h" +#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/default_fmha_grouped.h" +#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/gemm/gemm_grouped.h" + +namespace phi {{ + +using GemmCoord = cutlass::gemm::GemmCoord; + +struct Params {{ + // meta params + phi::DataType datatype; + + // [bs, nh, seq_len, dh] + const void* query_ptr; + const void* key_ptr; + const void* value_ptr; + + // and it can be broadcasted in axis0, 1, 2. + const void* mask_ptr = nullptr; + + const int* seq_lens = nullptr; + const int* kv_seq_lens = nullptr; + + // Output tensors + void* output_ptr; // [num_batches, num_heads, query_seq_len, head_size] + void* output_accum_ptr = + nullptr; // [num_batches, num_heads, query_seq_len, head_size] + + // Scale + float scale; + + // Dimensions/strides + int32_t num_batches; + int32_t num_heads; + int32_t query_seq_len; + int32_t key_value_seq_len; + int32_t head_size; + int32_t value_head_size; + + int64_t ldq; + int64_t ldk; + int64_t ldm; + int64_t ldv; + int64_t ldo; + + int64_t ElementQ; + int64_t ElementK; + int64_t ElementM; + int64_t ElementV; + int64_t ElementO; + + bool causal; + bool mask_broadcast_row; +}}; + +__global__ static void get_problem_sizes(const int* seq_lens, + const int* kv_seq_lens, + GemmCoord* problem_sizes0, + GemmCoord* problem_sizes1, + const int bs, + const int num_head, + const int head_size, + const int value_head_size) {{ + int bi = blockIdx.x; + int hi = threadIdx.x; + if (bi < bs && hi < num_head) {{ + int id = bi * num_head + hi; + int m = seq_lens[bi]; + int mkv = kv_seq_lens[bi]; + int k0 = head_size; + int k1 = value_head_size; + GemmCoord problem0(m, mkv, k0); + GemmCoord problem1(m, k1, mkv); + problem_sizes0[id] = problem0; + problem_sizes1[id] = problem1; + }} +}} + +template +struct CutlassTrait {{ + using Type = T; +}}; + +template <> +struct CutlassTrait {{ + using Type = cutlass::half_t; +}}; + +template <> +struct CutlassTrait {{ + using Type = cutlass::bfloat16_t; +}}; + + +template +struct ToPhiDTypeTrait {{ + private: + using NonConstT = typename std::remove_const::type; + static constexpr bool kIsFP16 = std::is_same::value; + static constexpr bool kIsBF16 = std::is_same::value; + + public: + using Type = typename std::conditional::type>::type; +}}; + +}} // namespace phi + +#include "./cutlass_forward.h" + +#endif +'''.format( + ENABLE_MACRO + ) + + path = Path(args.dst_path) / "autogen_variable" + os.makedirs(path, exist_ok=True) + path = Path(path) / "memory_efficient_variable_attention.h" + path.write_text(main_header_content) + + +if os.path.exists(Path(args.dst_path) / "autogen_variable"): + shutil.rmtree(Path(args.dst_path) / "autogen_variable") +forward_impl = "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen_variable/memory_efficient_variable_attention.h" + +write_main_header() + +write_decl_impl( + FwdKernel.get_all(), + "cutlass_forward", + impl_file=forward_impl, + enable_def=ENABLE_MACRO, +) diff --git a/paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu b/paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu new file mode 100644 index 00000000000..f8784d1966f --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.cu @@ -0,0 +1,132 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen_variable/cutlass_forward.h" + +namespace phi { +namespace fusion { + +template +void MultiHeadAttentionVariableForwardKernel( + const Context& ctx, + const DenseTensor& query, + const DenseTensor& key, + const DenseTensor& value, + const DenseTensor& seq_lens, + const DenseTensor& kv_seq_lens, + const paddle::optional& mask, + const float scale, + const bool causal, + DenseTensor* output) { + ctx.template Alloc(output); + Params params{}; + // [B, N, S, H] + params.seq_lens = seq_lens.data(); + params.kv_seq_lens = kv_seq_lens.data(); + + params.num_batches = query.dims()[0]; + params.num_heads = query.dims()[1]; + params.query_seq_len = query.dims()[2]; + params.head_size = query.dims()[3]; + params.key_value_seq_len = key.dims()[2]; + params.value_head_size = value.dims()[3]; + + params.datatype = query.dtype(); + params.query_ptr = query.data(); + params.key_ptr = key.data(); + params.value_ptr = value.data(); + + params.output_ptr = output->data(); + + params.ldq = params.head_size; + params.ldk = params.head_size; + params.ldv = params.value_head_size; + params.ldo = params.value_head_size; + + params.ElementQ = params.query_seq_len * params.head_size; + params.ElementK = params.key_value_seq_len * params.head_size; + params.ElementV = params.key_value_seq_len * params.value_head_size; + params.ElementO = params.query_seq_len * params.value_head_size; + + params.scale = scale; + params.causal = causal; + + if (mask) { + // [B, 1, S, D] + auto mask_tensor = mask.get(); + params.ldm = mask_tensor.dims()[3]; + params.ElementM = mask_tensor.dims()[2] * mask_tensor.dims()[3]; + params.mask_ptr = mask_tensor.data(); + params.mask_broadcast_row = false; + } + + bool kernel_launched = false; + + auto launchKernel = [&](auto k_, auto kernel_fn) { + using KernelType = decltype(k_); + if (kernel_launched) { + return; + } + if (mask && !KernelType::kAddMask) { + return; + } + if (!mask && KernelType::kAddMask) { + return; + } + if (KernelType::kMaskBroadcastRow) { + // not support mask_broad_cast + return; + } + if (mask && reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.ldm % (16 / sizeof(T)) == 0 && !KernelType::kMaskIsAligned) { + return; + } + if (mask && + !(reinterpret_cast(params.mask_ptr) % 16 == 0 && + params.ldm % (16 / sizeof(T)) == 0) && + KernelType::kMaskIsAligned) { + return; + } + if (KernelType::kSingleValueIteration && + KernelType::kKeysPerBlock < params.value_head_size) { + return; + } + if (KernelType::kKeysPerBlock == 64 && params.value_head_size > 64) { + return; + } + if (params.head_size % KernelType::MM0::kAlignmentA) { + return; + } + kernel_launched = true; + kernel_fn(k_, params, ctx); + }; + dispatch_cutlass_forward(ctx, launchKernel); + PADDLE_ENFORCE_EQ( + kernel_launched, + true, + phi::errors::InvalidArgument("the kernel should not be launched")); +} + +} // namespace fusion +} // namespace phi + 
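+// Register the fused kernel for float32/float16/bfloat16 activations below.
+// Note that seq_lens (input index 3) is declared as INT32, so it keeps its
+// integer dtype regardless of the kernel's compute dtype.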
+PD_REGISTER_KERNEL(variable_length_memory_efficient_attention, + GPU, + ALL_LAYOUT, + phi::fusion::MultiHeadAttentionVariableForwardKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetDataType(phi::DataType::INT32); +} diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 207a4fcb036..81b003fe8d1 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -25,6 +25,9 @@ from .fused_ec_moe import fused_ec_moe from .fused_dropout_add import fused_dropout_add from .fused_gate_attention import fused_gate_attention from .fused_rotary_position_embedding import fused_rotary_position_embedding +from .variable_length_memory_efficient_attention import ( + variable_length_memory_efficient_attention, +) from .rms_norm import rms_norm __all__ = [ @@ -38,5 +41,6 @@ __all__ = [ 'fused_ec_moe', 'fused_dropout_add', 'fused_rotary_position_embedding', + 'variable_length_memory_efficient_attention', "rms_norm", ] diff --git a/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py b/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py new file mode 100644 index 00000000000..06f9772628a --- /dev/null +++ b/python/paddle/incubate/nn/functional/variable_length_memory_efficient_attention.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The following codes are from https://github.com/facebookresearch/xformers + +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from paddle import _C_ops +from paddle.framework import LayerHelper, in_dynamic_mode + + +def variable_length_memory_efficient_attention( + query, + key, + value, + seq_lens, + kv_seq_lens, + mask=None, + scale=None, + causal=False, +): + """ + Cutlass Memory Efficient Variable Attention. + This method requires SM_ARCH in sm70, sm75, sm80. + + Args: + query (Tensor): The Query Tensor. Its shape is [batchsize, seq_len, num_head, head_size]. + key (Tensor): The Key Tensor. Its shape is [batchsize, seq_len, num_head, head_size]. + value (Tensor): The Value Tensor. Its shape is [batchsize, seq_len, num_head, head_size]. + seq_lens (Tensor): The sequence lengths of the sequences in the batch, used to index query. Its shape is [batchsize, 1]. + kv_seq_lens (Tensor): The sequence lengths of the sequences in the batch, used to index key and value. Its shape is [batchsize, 1]. + mask (Tensor): The Mask Tensor. Its shape is [batchsize, 1, query_seq_len, key_seq_len]. + scale (Float): The attention matrix's scale. Default is sqrt(1.0 / head_size). + causal (Bool): Whether causal masking is used or not. Default is False. + Returns: + Tensor: the output Tensor. + + Examples: + .. 
code-block:: python + + # required: gpu + import math + import paddle + from paddle.incubate.nn.functional import variable_length_memory_efficient_attention + + batch = 1 + num_head = 8 + seq_len = 256 + head_size = 32 + + dtype = paddle.float16 + + query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) + key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) + value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype) + seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32') + mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype) + + scale = float(1.0 / math.sqrt(head_size)) + + def naive_attention_impl(query, key, value, mask, scale): + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = qk_res * scale + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(softmax_result, value) + return result + + out = naive_attention_impl(query, key, value, mask, scale) + # equals to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale) + + print(out.shape) # [batch, seq_len, num_head, head_size] + """ + if scale is None: + head_size = query.shape[3] + scale = float(1.0 / math.sqrt(head_size)) + + if in_dynamic_mode(): + return _C_ops.variable_length_memory_efficient_attention( + query, key, value, seq_lens, kv_seq_lens, mask, scale, causal + ) + + helper = LayerHelper( + 'variable_length_memory_efficient_attention', **locals() + ) + out = helper.create_variable_for_type_inference(dtype=query.dtype) + helper.append_op( + type='variable_length_memory_efficient_attention', + inputs={ + 'query': query, + 'key': key, + 'value': value, + 'seq_lens': seq_lens, + 'kv_seq_lens': kv_seq_lens, + "mask": mask, + }, + attrs={"scale": scale, "causal": causal}, + outputs={'out': out}, + ) + return out diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index de79c7b35fa..0ec79246044 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -158,6 +158,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_linear_compress) list(REMOVE_ITEM TEST_OPS test_matmul_int8_op) + list(REMOVE_ITEM TEST_OPS test_variable_length_memory_efficient_attention) endif() list(REMOVE_ITEM TEST_OPS test_checkpoint_saver) diff --git a/test/legacy_test/test_variable_length_memory_efficient_attention.py b/test/legacy_test/test_variable_length_memory_efficient_attention.py new file mode 100644 index 00000000000..76630a8e6c2 --- /dev/null +++ b/test/legacy_test/test_variable_length_memory_efficient_attention.py @@ -0,0 +1,296 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
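+
+# The tests below compare the fused variable-length attention op against a
+# naive matmul + softmax reference for float32/float16/bfloat16 inputs in
+# dynamic graph mode, plus a float16 static-graph case (CUDA 11.2+ only).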
+ +import os +import re +import unittest + +import numpy as np + +import paddle +from paddle import fluid +from paddle.framework import core +from paddle.incubate.nn.functional import ( + variable_length_memory_efficient_attention, +) +from paddle.static import Program, program_guard + +paddle.seed(2023) + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +def create_attn_mask( + mask_type, + batch_size, + seq_lens, +): + max_seq_len = max(seq_lens) + mask = paddle.zeros( + [batch_size, 1, max_seq_len, max_seq_len], dtype=mask_type + ) + for i in range(batch_size): + seq_len = seq_lens[i] + mask[i, 0, :seq_len, :seq_len] = ( + paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)) + - 1 + ) * 1e4 + return mask + + +def naive_attention_impl(query, key, value, mask, scale): + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = qk_res * scale + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(softmax_result, value) + return result + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.2", +) +class TestMemEffAttentionVariableAPI(unittest.TestCase): + def setUp(self): + self.name = "MemEffAPIVariable_fp32" + self.place = paddle.CUDAPlace(0) + self.batch_size = 1 + self.num_head = 8 + self.seq_len = 64 + self.dim_head = 16 + self.seq_lens = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.dtype = 'float32' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + + def test_all(self): + paddle.disable_static() + + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + out_ = naive_attention_impl(q, k, v, self.attention_mask, self.scale) + + out = variable_length_memory_efficient_attention( + q, + k, + v, + self.seq_lens, + self.seq_lens, + self.attention_mask, + self.scale, + ) + + for i in range(self.batch_size): + np.testing.assert_allclose( + out.numpy()[i, :, : self.seq_lens[i], :], + out_[i, :, : self.seq_lens[i], :], + rtol=5e-03, + atol=1e-03, + ) + + +class TestMemEffAPIVariableDtypeFP16(TestMemEffAttentionVariableAPI): + def setUp(self): + self.name = "MemEffAPIVariable_fp16" + self.place = paddle.CUDAPlace(0) + self.batch_size = 3 + self.num_head = 16 + self.seq_len = 64 + self.dim_head = 32 + self.seq_lens = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + + 
+class TestMemEffAPIVariableDtypeBF16(TestMemEffAttentionVariableAPI): + def setUp(self): + self.name = "MemEffAPIVariable_bf16" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 32 + self.dim_head = 128 + self.seq_lens = paddle.to_tensor( + [ + self.seq_len // 2, + self.seq_len, + ], + "int32", + ) + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.dtype = 'bfloat16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.2", +) +class TestMemEffAPIVariableDtypeFP16Static(unittest.TestCase): + def setUp(self): + self.name = "MemEffAPIVariableStatic_fp16" + self.place = paddle.CUDAPlace(0) + self.batch_size = 3 + self.num_head = 16 + self.seq_len = 64 + self.dim_head = 32 + self.seq_lens = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ).numpy() + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ).numpy() + self.q = np.random.random(self.shape).astype(self.dtype) + self.k = np.random.random(self.shape).astype(self.dtype) + self.v = np.random.random(self.shape).astype(self.dtype) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + + self.ref_out = naive_attention_impl( + paddle.to_tensor(self.q), + paddle.to_tensor(self.k), + paddle.to_tensor(self.v), + paddle.to_tensor(self.attention_mask), + self.scale, + ) + + def test_all(self): + paddle.enable_static() + with program_guard(Program(), Program()): + q = paddle.static.data( + name="query", shape=self.shape, dtype=self.dtype + ) + k = paddle.static.data( + name="key", shape=self.shape, dtype=self.dtype + ) + v = paddle.static.data( + name="value", shape=self.shape, dtype=self.dtype + ) + mask = paddle.static.data( + name="mask", + shape=[self.batch_size, 1, self.seq_len, self.seq_len], + dtype=self.dtype, + ) + seq_lens = paddle.static.data( + name="seq_lens", shape=[self.batch_size, 1], dtype="int32" + ) + out = variable_length_memory_efficient_attention( + q, k, v, seq_lens, seq_lens, mask, self.scale + ) + exe = fluid.Executor() + res = exe.run( + feed={ + "query": self.q, + "key": self.k, + "value": self.v, + "seq_lens": self.seq_lens, + "mask": self.attention_mask, + }, + fetch_list=[out], + ) + paddle.disable_static() + np.testing.assert_allclose(res[0], self.ref_out, rtol=5e-03, atol=1e-03) + + +if __name__ == '__main__': + unittest.main() -- GitLab