Unverified commit 61611786, authored by Chitsing KUI, committed by GitHub

Integration flash attention (#49869)

* flash attn

* seed

* almost

* softmax

* fix workspace

* add unitest; linux only

* fix setup

* fix datatype include

* fix setup typo

* fix def scope

* new error api

* use paddle fork

* fix attr bug; complete ut

* update flash hash

* fix rng reset

* fix offset

* fix comments
Parent 5751b7f4
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
add_definitions(-DPADDLE_WITH_FLASHATTN)
set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR
"${FLASHATTN_INSTALL_DIR}/lib"
CACHE PATH "flash-attn Library Directory" FORCE)
if(WIN32)
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
else()
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()
if(WIN32)
set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
ExternalProject_Add(
extern_flashattn
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${FLASHATTN_REPOSITORY}
GIT_TAG ${FLASHATTN_TAG}
PREFIX ${FLASHATTN_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
include_directories(${FLASHATTN_INCLUDE_DIR})
add_library(flashattn INTERFACE)
#set_property(TARGET flashattn PROPERTY IMPORTED_LOCATION ${FLASHATTN_LIBRARIES})
add_dependencies(flashattn extern_flashattn)
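# Note: flashattn is an INTERFACE target; it only carries the include path and
# a build-order dependency on extern_flashattn. The shared library is not
# linked at build time: it is opened at runtime through phi::dynload
# (GetFlashAttnDsoHandle) and copied into paddle/libs by the setup scripts.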
@@ -531,6 +531,9 @@ if(WITH_GPU
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
endif()
endif()
...
@@ -508,6 +508,16 @@
func : fill_diagonal_tensor_grad
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnGradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_grad
- backward_op : flip_grad
forward : flip (Tensor x, int[] axis) -> Tensor(out)
args : (Tensor out_grad, int[] axis)
...
@@ -482,6 +482,16 @@
inplace : (x -> out)
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
kernel :
func : flash_attn
backward : flash_attn_grad
- op : flip
args : (Tensor x, int[] axis)
output : Tensor (out)
...
@@ -92,6 +92,13 @@ if(WITH_MKLML)
DEPS phi_dynamic_loader mklml)
endif()
if(WITH_FLASHATTN)
cc_library(
phi_dynload_flashattn
SRCS flashattn.cc
DEPS phi_dynamic_loader flashattn)
endif()
cc_library(
phi_dynload_lapack
SRCS lapack.cc
...
@@ -484,6 +484,20 @@ void* GetWarpRNNTDsoHandle() {
#endif
}
void* GetFlashAttnDsoHandle() {
std::string flashattn_dir = "";
if (!s_py_site_pkg_path.path.empty()) {
flashattn_dir = s_py_site_pkg_path.path;
}
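// When the Python side has registered a site-packages library path,
// flashattn_dir points there; otherwise GetDsoHandleFromSearchPath falls back
// to the default dynamic-loader search locations for libflashattn.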
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn.dll");
#else
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.so");
#endif
}
void* GetNCCLDsoHandle() {
#ifdef PADDLE_WITH_HIP
std::string warning_msg(
...
@@ -36,6 +36,7 @@ void* GetNVRTCDsoHandle();
void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetFlashAttnDsoHandle();
void* GetNCCLDsoHandle();
void* GetHCCLDsoHandle();
void* GetTensorRtDsoHandle();
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/dynload/flashattn.h"
namespace phi {
namespace dynload {
std::once_flag flashattn_dso_flag;
void* flashattn_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
FLASHATTN_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex> // NOLINT
#include "flashattn/include/flash_attn.h"
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
namespace phi {
namespace dynload {
extern std::once_flag flashattn_dso_flag;
extern void* flashattn_dso_handle;
#define DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using flashattnFunc = decltype(&::__name); \
std::call_once(flashattn_dso_flag, []() { \
flashattn_dso_handle = phi::dynload::GetFlashAttnDsoHandle(); \
}); \
static void* p_##__name = dlsym(flashattn_dso_handle, #__name); \
return reinterpret_cast<flashattnFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_error);
FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
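// The expansion above declares DynLoad__flash_attn_fwd, DynLoad__flash_attn_bwd
// and DynLoad__flash_attn_error: each wrapper opens libflashattn once via
// GetFlashAttnDsoHandle(), resolves its symbol once with dlsym, and then
// forwards every call with the original signature.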
#undef DYNAMIC_LOAD_FLASHATTN_WRAP
} // namespace dynload
} // namespace phi
@@ -198,6 +198,23 @@ void CropGradInferMeta(const MetaTensor& out_grad,
}
}
void FlashAttnGradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
if (dq) {
dq->share_meta(q);
}
if (dk && k) {
dk->share_meta(k);
}
if (dv && v) {
dv->share_meta(v);
}
}
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
...
@@ -168,6 +168,13 @@ void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad,
int dim2,
MetaTensor* x_grad);
void FlashAttnGradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
...
@@ -255,6 +255,18 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
output_box->set_dtype(target_box.dtype());
}
void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax_lse,
MetaTensor* softmax,
MetaTensor* seed_offset) {
out->set_dims(q.dims());
out->set_dtype(q.dtype());
out->set_layout(q.layout());
}
void ArangeInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
...
@@ -63,6 +63,14 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
MetaTensor* output_box,
MetaConfig config = MetaConfig());
void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax_lse,
MetaTensor* softmax,
MetaTensor* seed_offset);
void InstanceNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
...
@@ -79,6 +79,10 @@ set(COMMON_KERNEL_DEPS
utf8proc
gather_scatter_functor)
if(WITH_FLASHATTN)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_dynload_flashattn)
endif()
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_NCCL OR WITH_RCCL)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl)
...
@@ -25,4 +25,11 @@ void ArangeKernel(const Context& dev_ctx,
const DenseTensor& step,
DenseTensor* out);
template <typename T, typename Context>
void ArangeNullaryKernel(const Context& dev_ctx,
const T start,
const T end,
const T step,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
DenseTensor* out,
DenseTensor* softmax_lse,
DenseTensor* softmax,
DenseTensor* seed_offset);
} // namespace phi
@@ -52,6 +52,30 @@ void ArangeKernel(const Context& dev_ctx,
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
}
template <typename T, typename Context>
void ArangeNullaryKernel(const Context& dev_ctx,
const T start_value,
const T end_value,
const T step_value,
DenseTensor* out) {
int64_t size = 0;
phi::funcs::GetSize(start_value, end_value, step_value, &size);
out->Resize(phi::make_ddim({size}));
T* out_data = dev_ctx.template Alloc<T>(out);
auto stream = dev_ctx.stream();
int64_t block = std::min(size, static_cast<int64_t>(256));
if (block == 0) {
return;
}
int64_t grid = (size + block - 1) / block;
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
}
template decltype(ArangeNullaryKernel<int64_t, phi::GPUContext>)
ArangeNullaryKernel;
template decltype(ArangeNullaryKernel<int, phi::GPUContext>)
ArangeNullaryKernel;
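// Explicit instantiations so that other translation units (e.g. the flash_attn
// kernels, which build int32_t cu_seqlens) can call ArangeNullaryKernel
// directly instead of going through the registered ArangeKernel.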
} // namespace phi
PD_REGISTER_KERNEL(
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
namespace phi {
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16;
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
// q,k,v [total_*, num_heads, head_dim]
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
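// cu_seqlens_q/k hold the cumulative sequence offsets
// [0, seq_len, 2 * seq_len, ..., batch_size * seq_len] that index into the
// packed [total_*, num_heads, head_size] views of q, k and v built above.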
float scale = 1.0f / std::sqrt(head_size);
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
std::vector<int64_t> seed_offset_vec;
phi::TensorToVector<int64_t>(seed_offset, ctx, &seed_offset_vec);
uint64_t seed = seed_offset_vec[0];
uint64_t offset = seed_offset_vec[1];
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd(
q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
dq->data(),
dk->data(),
dv->data(),
nullptr, // null output pointer: this first call only queries workspace_size
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd(
q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
dq->data(),
dk->data(),
dv->data(),
out.data(),
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(phi::Backend::CPU); // seed_offset
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
namespace phi {
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
DenseTensor* out,
DenseTensor* softmax_lse,
DenseTensor* softmax,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16;
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
DenseTensor q_t_s =
Reshape<T, Context>(ctx, q, {total_q, num_heads, head_size});
DenseTensor k_t_s =
Reshape<T, Context>(ctx, k, {total_k, num_heads, head_size});
DenseTensor v_t_s =
Reshape<T, Context>(ctx, v, {total_k, num_heads, head_size});
// q,k,v [total_*, num_heads, head_dim]
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
float scale = 1.0f / std::sqrt(head_size);
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
std::vector<int64_t> seed_offset_vec{int64_t(seed), int64_t(offset)};
phi::TensorFromVector<int64_t>(seed_offset_vec, ctx, seed_offset);
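// The seed/offset pair is saved into seed_offset so that flash_attn_grad can
// replay exactly the same dropout pattern; the backward kernel reads this
// tensor on the CPU (its seed_offset input is forced to the CPU backend).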
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *seq_len_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t max_len_k_ =
((seq_len_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
int64_t max_len_k =
seq_len_k <= 128 ? 128 : (seq_len_k <= 256 ? 256 : max_len_k_);
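// Example: with head_size <= 64 (blocksize_c = 256) and seq_len_k = 300, the
// padded length becomes max_len_k = 512, so the returned softmax is wider
// than the real key sequence length.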
softmax->Resize({batch_size, num_heads, seq_len_q, max_len_k});
ctx.template Alloc<T>(softmax);
}
uint64_t workspace_size;
// calculate workspace size before execution
bool succ =
phi::dynload::flash_attn_fwd(q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
nullptr, // null output pointer: this first call only queries workspace_size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd(
q_t_s.data(),
k_t_s.data(),
v_t_s.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
seq_len_q,
seq_len_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn,
GPU,
ALL_LAYOUT,
phi::FlashAttnKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -12,6 +12,7 @@ env_dict={
'FLUID_CORE_NAME':'@FLUID_CORE_NAME@',
'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@',
'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@',
'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@',
'LAPACK_LIB':'@LAPACK_LIB@',
'GFORTRAN_LIB':'@GFORTRAN_LIB@',
'GNU_RT_LIB_1':'@GNU_RT_LIB_1@',
...
@@ -486,6 +486,10 @@ if(NOT WITH_GPU
list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass)
endif()
if(NOT WITH_FLASHATTN)
list(REMOVE_ITEM TEST_OPS test_flash_attention)
endif()
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import flash_attention
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def attention_naive(q, k, v, causal=False):
qt = paddle.transpose(q, [0, 2, 1, 3])
kt = paddle.transpose(k, [0, 2, 1, 3])
vt = paddle.transpose(v, [0, 2, 1, 3])
scale = 1.0 / np.sqrt(q.shape[-1])
s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
s = paddle.scale(s, scale)
p = (
paddle.incubate.softmax_mask_fuse_upper_triangle(s)
if causal
else F.softmax(s)
)
o = paddle.matmul(p, vt)
return paddle.transpose(o, [0, 2, 1, 3])
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
)
class TestFlashAttentionAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 16)
self.blocksize = 2
self.dtype = 'float16'
self.dropout = 0.0
self.causal = False
self.return_softmax = False
def test_all(self):
print(
f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
)
# test dynamic
paddle.disable_static()
query = np.random.random(self.shape)
key = np.random.random(self.shape)
value = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k_ = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v_ = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
out_ = attention_naive(q_, k_, v_, self.causal)
out.backward()
out_.backward()
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
self.assertEqual(q.grad.shape, q.shape)
self.assertEqual(q_.grad.shape, q.shape)
np.testing.assert_allclose(
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)
# test static
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
qs = paddle.static.data(
name="q", shape=self.shape, dtype=self.dtype
)
ks = paddle.static.data(
name="k", shape=self.shape, dtype=self.dtype
)
vs = paddle.static.data(
name="v", shape=self.shape, dtype=self.dtype
)
outs, softmax = flash_attention(
qs, ks, vs, self.dropout, self.causal, self.return_softmax
)
exe = fluid.Executor(self.place)
fetches_result = exe.run(
feed={
"q": query.astype('float16'),
"k": key.astype('float16'),
"v": value.astype('float16'),
},
fetch_list=[outs],
)
np.testing.assert_allclose(
fetches_result[0], out_, rtol=5e-03, atol=1e-03
)
class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 128, 8, 16)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 256, 8, 16)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = True
class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 512, 8, 16)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = True
self.return_softmax = False
class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.blocksize = 2
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
def flash_attention(
query,
key,
value,
dropout=0.0,
causal=False,
return_softmax=False,
name=None,
):
r"""
The equation is:
.. math::
result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Warning:
This API only supports inputs with dtype float16 and bfloat16.
Args:
query(Tensor): The query tensor in the Attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
key(Tensor): The key tensor in the Attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
value(Tensor): The value tensor in the Attention module.
4-D tensor with shape:
[batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
dropout(float): The dropout ratio.
causal(bool): Whether to enable causal mode.
return_softmax(bool): Whether to return the softmax tensor.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
Returns:
out(Tensor): The attention tensor.
4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
The dtype can be float16 or bfloat16.
softmax(Tensor): The softmax tensor. None if return_softmax is False.
Examples:
.. code-block:: python
# required: skiptest
import paddle
from paddle.nn.functional.flash_attention import flash_attention
q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
output, _ = flash_attention(q, q, q, 0.9, False, False)
print(output)
"""
if in_dynamic_mode():
(
result_attention,
result_softmax_lse,
result_softmax,
seed_offset,
) = _C_ops.flash_attn(
query,
key,
value,
dropout,
causal,
return_softmax,
)
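# softmax_lse and seed_offset are intermediate outputs consumed by the
# flash_attn_grad backward op; only the attention result and (optionally)
# the softmax are returned to the caller.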
return result_attention, result_softmax
helper = LayerHelper('flash_attn', **locals())
dtype = helper.input_dtype(input_param_name='q')
out = helper.create_variable_for_type_inference(dtype)
softmax = helper.create_variable_for_type_inference(dtype)
softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
seed_offset = helper.create_variable_for_type_inference(paddle.int64)
inputs = {
'q': query,
'k': key,
'v': value,
}
outputs = {
'out': out,
'softmax': softmax,
'softmax_lse': softmax_lse,
'seed_offset': seed_offset,
}
helper.append_op(
type='flash_attn',
inputs=inputs,
outputs=outputs,
attrs={
'dropout': dropout,
'causal': causal,
'return_softmax': return_softmax,
},
)
return out, softmax
...@@ -545,6 +545,10 @@ if not sys.platform.startswith("linux"): ...@@ -545,6 +545,10 @@ if not sys.platform.startswith("linux"):
package_data['paddle.libs']+=[os.path.basename('${GNU_RT_LIB_2}')] package_data['paddle.libs']+=[os.path.basename('${GNU_RT_LIB_2}')]
shutil.copy('${GNU_RT_LIB_2}', libs_path) shutil.copy('${GNU_RT_LIB_2}', libs_path)
if len('${FLASHATTN_LIBRARIES}') > 1:
package_data['paddle.libs']+=[os.path.basename('${FLASHATTN_LIBRARIES}')]
shutil.copy('${FLASHATTN_LIBRARIES}', libs_path)
if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_SHARED_LIB}', libs_path)
shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path)
...
@@ -918,6 +918,11 @@ def get_package_data_and_package_dir():
shutil.copy(env_dict.get("OPENBLAS_LIB") + '.0', libs_path)
package_data['paddle.libs'] += ['libopenblas.so.0']
if len(env_dict.get("FLASHATTN_LIBRARIES", "")) > 1:
package_data['paddle.libs'] += [
os.path.basename(env_dict.get("FLASHATTN_LIBRARIES"))
]
shutil.copy(env_dict.get("FLASHATTN_LIBRARIES"), libs_path)
if env_dict.get("WITH_LITE") == 'ON': if env_dict.get("WITH_LITE") == 'ON':
shutil.copy(env_dict.get("LITE_SHARED_LIB"), libs_path) shutil.copy(env_dict.get("LITE_SHARED_LIB"), libs_path)
package_data['paddle.libs'] += [ package_data['paddle.libs'] += [
......