Unverified commit 5d991c6f authored by sneaxiy, committed by GitHub

Make flash attn v1 available (#56040)

* make flash attn v1 available

* add deps error

* refine cmake dependencies

* fix cmake error
Parent: cc9a7688
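The patch keeps flash-attention v2 as the default and gates v1 behind an environment flag that the Python layer reads once at import time (see the flash_attention.py hunks below). A minimal opt-in sketch, assuming a CUDA build with the flash-attn libraries bundled:

import os

# The flag is read when paddle.nn.functional.flash_attention is first
# imported, so it must be set before importing paddle.
os.environ['FLAGS_flash_attn_version'] = 'v1'  # default is 'v2'

import paddle  # noqa: E402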
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
add_definitions(-DPADDLE_WITH_FLASHATTN)
set(FLASHATTN_V1_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn_v1)
set(FLASHATTN_V1_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_V1_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn_v1)
set(FLASHATTN_V1_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_V1_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_V1_INCLUDE_DIR
"${FLASHATTN_V1_INSTALL_DIR}/include"
CACHE PATH "flash-attn v1 Directory" FORCE)
set(FLASHATTN_V1_LIB_DIR
"${FLASHATTN_V1_INSTALL_DIR}/lib"
CACHE PATH "flash-attn v1 Library Directory" FORCE)
if(WIN32)
set(FLASHATTN_V1_OLD_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
set(FLASHATTN_V1_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
else()
set(FLASHATTN_V1_OLD_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
set(FLASHATTN_V1_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()
if(WIN32)
set(FLASHATTN_V1_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FLASHATTN_V1_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_V1_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_V1_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_V1_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_V1_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_V1_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
ExternalProject_Add(
extern_flashattn_v1
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${FLASHATTN_V1_REPOSITORY}
GIT_TAG ${FLASHATTN_V1_TAG}
PREFIX ${FLASHATTN_V1_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_V1_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_V1_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_V1_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_V1_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_V1_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_V1_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_V1_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_V1_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_V1_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_V1_LIBRARIES})
add_custom_target(
extern_flashattn_v1_move_lib
COMMAND ${CMAKE_COMMAND} -E copy ${FLASHATTN_V1_OLD_LIBRARIES}
${FLASHATTN_V1_LIBRARIES})
add_dependencies(extern_flashattn_v1_move_lib extern_flashattn_v1)
message(STATUS "flash-attn v1 library: ${FLASHATTN_V1_LIBRARIES}")
get_filename_component(FLASHATTN_V1_LIBRARY_PATH ${FLASHATTN_V1_LIBRARIES}
DIRECTORY)
include_directories(${FLASHATTN_V1_INCLUDE_DIR})
add_library(flashattn_v1 INTERFACE)
#set_property(TARGET flashattn_v1 PROPERTY IMPORTED_LOCATION ${FLASHATTN_V1_LIBRARIES})
add_dependencies(flashattn_v1 extern_flashattn_v1 extern_flashattn_v1_move_lib)
......@@ -516,7 +516,8 @@ if(WITH_GPU
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
include(external/flashattn_v1)
list(APPEND third_party_deps extern_flashattn extern_flashattn_v1)
set(WITH_FLASHATTN ON)
break()
endif()
......
......@@ -638,6 +638,28 @@
func : flash_attn_unpadded_grad
data_type: q
- backward_op : flash_attn_v1_grad
forward : flash_attn_v1 (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnV1GradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_grad
data_type: q
- backward_op : flash_attn_v1_unpadded_grad
forward : flash_attn_v1_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnV1GradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_unpadded_grad
data_type: q
- backward_op : flatten_grad
forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
......
......@@ -703,6 +703,30 @@
intermediate : softmax_lse, seed_offset
backward : flash_attn_unpadded_grad
- op : flash_attn_v1
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnV1InferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_v1_grad
- op : flash_attn_v1_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnV1InferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_unpadded
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_v1_unpadded_grad
- op : flatten
args : (Tensor x, int start_axis = 1, int stop_axis = 1)
output : Tensor(out), Tensor(xshape)
......
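For reference, a hedged sketch of invoking the new padded op directly in dynamic mode, following the signature declared above (shapes illustrative; requires a GPU build with flash-attn v1):

import paddle
from paddle import _C_ops

# [batch_size, seq_len, num_heads, head_dim], float16 or bfloat16
q = paddle.randn([2, 128, 8, 64], dtype='float16')
k = paddle.randn([2, 128, 8, 64], dtype='float16')
v = paddle.randn([2, 128, 8, 64], dtype='float16')

# Positional args after the tensors: dropout, causal, return_softmax,
# is_test. softmax_lse and seed_offset are declared intermediate, so the
# eager API hands back two values (matching the unpacking in the Python
# hunks later in this commit).
out, softmax = _C_ops.flash_attn_v1(q, k, v, 0.0, True, False, False)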
......@@ -97,8 +97,8 @@ endif()
if(WITH_FLASHATTN)
cc_library(
phi_dynload_flashattn
SRCS flashattn.cc
DEPS phi_dynamic_loader flashattn)
SRCS flashattn.cc flashattn_v1.cc
DEPS phi_dynamic_loader flashattn flashattn_v1)
endif()
cc_library(
......
......@@ -500,6 +500,20 @@ void* GetFlashAttnDsoHandle() {
#endif
}
void* GetFlashAttnV1DsoHandle() {
std::string flashattn_dir = "";
if (!s_py_site_pkg_path.path.empty()) {
flashattn_dir = s_py_site_pkg_path.path;
}
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn_v1.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn_v1.dll");
#else
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn_v1.so");
#endif
}
void* GetNCCLDsoHandle() {
#ifdef PADDLE_WITH_HIP
std::string warning_msg(
......
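The loader falls back to the Python site-packages path recorded in s_py_site_pkg_path; the setup.py hunks at the end of this commit copy the renamed library into paddle/libs so it can be found there. A quick check of the bundled libraries after installation (illustrative, not part of the patch):

import os
import paddle

libs_dir = os.path.join(os.path.dirname(paddle.__file__), 'libs')
print(sorted(f for f in os.listdir(libs_dir) if 'flashattn' in f))
# expected to include libflashattn_v1.so (flashattn_v1.dll on Windows)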
......@@ -37,6 +37,7 @@ void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetFlashAttnDsoHandle();
void* GetFlashAttnV1DsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle();
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/dynload/flashattn_v1.h"
namespace phi {
namespace dynload {
std::once_flag flashattn_v1_dso_flag;
void* flashattn_v1_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name##__v1 __name##_v1
FLASHATTN_V1_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex> // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
namespace phi {
namespace dynload {
extern std::once_flag flashattn_v1_dso_flag;
extern void *flashattn_v1_dso_handle;
using flash_attn_fwd_v1_func_t = bool (*)(
const void * /*q*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*k*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*v*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*out*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*cu_seqlens_q*/, // int32, batch_size+1, starting offset of
// each sequence
const void * /*cu_seqlens_k*/, // int32, batch_size+1, starting offset of
// each sequence
const int /*total_q*/,
const int /*total_k*/,
const int /*batch_size*/,
const int /*num_heads*/,
const int /*head_size*/,
const int /*max_seqlen_q_*/,
const int /*max_seqlen_k_*/,
const float /*p_dropout*/,
const float /*softmax_scale*/,
const bool /*zero_tensors*/,
const bool /*is_causal*/,
const bool /*is_bf16*/,
const int /*num_splits*/, // SMs per attention matrix, can be 1
void * /*softmax_lse_ptr*/, // softmax log_sum_exp
void * /*softmax_ptr*/,
void * /*workspace_ptr*/,
uint64_t * /*workspace_size*/,
cudaStream_t /*stream*/,
uint64_t /*seed*/,
uint64_t /*offset*/
);
using flash_attn_bwd_v1_func_t = bool (*)(
const void * /*q*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*k*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*v*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*dq*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
void * /*dk*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*dv*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*out*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*dout*/, // total_q x num_heads x head_size
const void * /*cu_seqlens_q*/, // int32, batch_size+1
const void * /*cu_seqlens_k*/, // int32, batch_size+1
const int /*total_q*/,
const int /*total_k*/,
const int /*batch_size*/,
const int /*num_heads*/,
const int /*head_size*/,
const int /*max_seqlen_q_*/,
const int /*max_seqlen_k_*/,
const float /*p_dropout*/,
const float /*softmax_scale*/,
const bool /*zero_tensors*/,
const bool /*is_causal*/,
const bool /*is_bf16*/,
const int /*num_splits*/,
void * /*softmax_lse_ptr*/,
void * /*dsoftmax_ptr*/,
void * /*workspace_ptr*/,
uint64_t * /*workspace_size*/,
cudaStream_t /*stream*/,
uint64_t /*seed*/,
uint64_t /*offset*/
);
using flash_attn_error_v1_func_t = const char *(*)();
#define DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
struct DynLoad__##__name##__v1 { \
template <typename... Args> \
auto operator()(Args... args) { \
using flashattnFunc = ::phi::dynload::__name##_v1_func_t; \
std::call_once(flashattn_v1_dso_flag, []() { \
flashattn_v1_dso_handle = phi::dynload::GetFlashAttnV1DsoHandle(); \
}); \
static void *p_##__name = dlsym(flashattn_v1_dso_handle, #__name); \
return reinterpret_cast<flashattnFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name##__v1 __name##_v1
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name)
#define FLASHATTN_V1_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_error);
FLASHATTN_V1_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP);
#undef DYNAMIC_LOAD_FLASHATTN_V1_WRAP
} // namespace dynload
} // namespace phi
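The wrapper macro resolves each symbol lazily: std::call_once opens the DSO on first use, then dlsym looks up the undecorated symbol name (the v1 library itself still exports flash_attn_fwd and friends; only the Paddle-side wrappers carry the _v1 suffix). The same pattern in a short Python ctypes sketch, assuming libflashattn_v1.so is on the loader path (illustrative only):

import ctypes

_dso = None

def _symbol(name):
    global _dso
    if _dso is None:  # plays the role of std::call_once on the dso flag
        _dso = ctypes.CDLL('libflashattn_v1.so')
    return getattr(_dso, name)

def flash_attn_error_v1():
    fn = _symbol('flash_attn_error')  # undecorated name, as in dlsym above
    fn.restype = ctypes.c_char_p
    return fn().decode()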
......@@ -219,6 +219,23 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
}
}
void FlashAttnV1GradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
if (dq) {
dq->share_meta(q);
}
if (dk && k) {
dk->share_meta(k);
}
if (dv && v) {
dv->share_meta(v);
}
}
void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
......
......@@ -179,6 +179,13 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
MetaTensor* dk,
MetaTensor* dv);
void FlashAttnV1GradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
......
......@@ -269,6 +269,20 @@ void FlashAttnInferMeta(const MetaTensor& q,
out->set_layout(q.layout());
}
void FlashAttnV1InferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
auto out_dims = q.dims();
out_dims[3] = v.dims()[3];
out->set_dims(out_dims);
out->set_dtype(q.dtype());
out->set_layout(q.layout());
}
void ArangeInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
......
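A small worked example of the shape rule above: the output keeps q's leading dims and takes its last dim (the head size) from v.

q_shape = [2, 128, 8, 64]   # [batch, seq_len, num_heads, head_dim of q/k]
v_shape = [2, 128, 8, 32]   # head_dim of v may differ
out_shape = q_shape[:3] + v_shape[3:]
assert out_shape == [2, 128, 8, 32]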
......@@ -71,6 +71,14 @@ void FlashAttnInferMeta(const MetaTensor& q,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);
void FlashAttnV1InferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);
void InstanceNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
template <typename T, typename Context>
void FlashAttnV1GradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnV1Kernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_v1_grad_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn_v1.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16;
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd_v1(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
nullptr, // for calculation workspace size
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd_v1(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
out.data(),
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
#endif
}
template <typename T, typename Context>
void FlashAttnV1GradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnV1UnpaddedGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
dout,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
dq,
dk,
dv);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_v1_unpadded_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1UnpaddedGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
PD_REGISTER_KERNEL(flash_attn_v1_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1GradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn_v1.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16;
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
PADDLE_ENFORCE_EQ(
dims.size(),
3,
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seq_len_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seq_len_k = 128;
} else if (max_seqlen_k <= 256) {
seq_len_k = 256;
}
softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
ctx.template Alloc<T>(softmax);
}
uint64_t workspace_size;
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_fwd_v1(
q.data(),
k.data(),
v.data(),
nullptr, // for calculation workspace size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd_v1(
q.data(),
k.data(),
v.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
#endif
}
template <typename T, typename Context>
void FlashAttnV1Kernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
PADDLE_ENFORCE_EQ(dims.size(),
4,
phi::errors::InvalidArgument(
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnV1UnpaddedKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
return_softmax,
is_test,
out,
softmax,
softmax_lse,
seed_offset);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_v1_unpadded,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1UnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(flash_attn_v1,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1Kernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
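The padded kernels reduce to the unpadded ones by flattening [batch, seq_len, heads, head_dim] into [batch*seq_len, heads, head_dim] with ShareDataWith/Resize and synthesizing the cumulative sequence offsets with an arange, since every padded sequence has the same length. The offsets, worked out in NumPy:

import numpy as np

batch_size, seq_len_q = 2, 128
cu_seqlens_q = np.arange(
    0, (batch_size + 1) * seq_len_q, seq_len_q, dtype=np.int32)
print(cu_seqlens_q)  # [  0 128 256]: sequence i spans rows [128*i, 128*(i+1))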
......@@ -13,6 +13,7 @@ env_dict={
'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@',
'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@',
'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@',
'FLASHATTN_V1_LIBRARIES':'@FLASHATTN_V1_LIBRARIES@',
'LAPACK_LIB':'@LAPACK_LIB@',
'GFORTRAN_LIB':'@GFORTRAN_LIB@',
'GNU_RT_LIB_1':'@GNU_RT_LIB_1@',
......
......@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
g_use_flash_attn_v1 = (
os.getenv('FLAGS_flash_attn_version', 'v2').strip().lower() == 'v1'
)
def flash_attention(
query,
......@@ -85,17 +91,28 @@ def flash_attention(
print(output)
"""
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
if g_use_flash_attn_v1:
(result_attention, result_softmax,) = _C_ops.flash_attn_v1(
query,
key,
value,
dropout,
causal,
return_softmax,
not training,
)
else:
(result_attention, result_softmax,) = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
helper = LayerHelper('flash_attn', **locals())
......@@ -210,22 +227,38 @@ def flash_attn_unpadded(
print(output)
"""
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
if g_use_flash_attn_v1:
(result_attention, result_softmax,) = _C_ops.flash_attn_v1_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
not training,
)
else:
(result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
helper = LayerHelper('flash_attn_unpadded', **locals())
......
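End to end, the public API is unchanged; only the flag decides which backend op runs. A hedged usage sketch (requires a CUDA build with flash-attn; tensor sizes illustrative):

import os
os.environ['FLAGS_flash_attn_version'] = 'v1'  # before importing paddle

import paddle
from paddle.nn.functional.flash_attention import flash_attention

q = paddle.randn([2, 128, 8, 64], dtype='float16')
out, _ = flash_attention(q, q, q, dropout=0.0, causal=True)
print(out.shape)  # [2, 128, 8, 64]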
......@@ -599,6 +599,10 @@ if len('${FLASHATTN_LIBRARIES}') > 1:
package_data['paddle.libs']+=[os.path.basename('${FLASHATTN_LIBRARIES}')]
shutil.copy('${FLASHATTN_LIBRARIES}', libs_path)
if len('${FLASHATTN_V1_LIBRARIES}') > 1:
package_data['paddle.libs']+=[os.path.basename('${FLASHATTN_V1_LIBRARIES}')]
shutil.copy('${FLASHATTN_V1_LIBRARIES}', libs_path)
if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_SHARED_LIB}', libs_path)
shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path)
......
......@@ -1036,6 +1036,13 @@ def get_package_data_and_package_dir():
os.path.basename(env_dict.get("FLASHATTN_LIBRARIES"))
]
shutil.copy(env_dict.get("FLASHATTN_LIBRARIES"), libs_path)
if len(env_dict.get("FLASHATTN_V1_LIBRARIES", "")) > 1:
package_data['paddle.libs'] += [
os.path.basename(env_dict.get("FLASHATTN_V1_LIBRARIES"))
]
shutil.copy(env_dict.get("FLASHATTN_V1_LIBRARIES"), libs_path)
if env_dict.get("WITH_LITE") == 'ON':
shutil.copy(env_dict.get("LITE_SHARED_LIB"), libs_path)
package_data['paddle.libs'] += [
......