From 61611786fa7e2607b51c19ed4d83a6a4d1a3d01d Mon Sep 17 00:00:00 2001
From: Chitsing KUI
Date: Wed, 1 Mar 2023 23:03:33 +0800
Subject: [PATCH] Integrate flash attention (#49869)

* flash attn
* seed
* almost
* softmax
* fix workspace
* add unittest; linux only
* fix setup
* fix datatype include
* fix setup typo
* fix def scope
* new error api
* use paddle fork
* fix attr bug; complete ut
* update flash hash
* fix rng reset
* fix offset
* fix comments
---
 cmake/external/flashattn.cmake                | 108 ++++++++++
 cmake/third_party.cmake                       |   3 +
 paddle/phi/api/yaml/backward.yaml             |  10 +
 paddle/phi/api/yaml/ops.yaml                  |  10 +
 paddle/phi/backends/dynload/CMakeLists.txt    |   7 +
 paddle/phi/backends/dynload/dynamic_loader.cc |  14 ++
 paddle/phi/backends/dynload/dynamic_loader.h  |   1 +
 paddle/phi/backends/dynload/flashattn.cc      |  28 +++
 paddle/phi/backends/dynload/flashattn.h       |  56 +++++
 paddle/phi/infermeta/backward.cc              |  17 ++
 paddle/phi/infermeta/backward.h               |   7 +
 paddle/phi/infermeta/ternary.cc               |  12 ++
 paddle/phi/infermeta/ternary.h                |   8 +
 paddle/phi/kernels/CMakeLists.txt             |   4 +
 paddle/phi/kernels/arange_kernel.h            |   7 +
 paddle/phi/kernels/flash_attn_grad_kernel.h   |  37 ++++
 paddle/phi/kernels/flash_attn_kernel.h        |  35 ++++
 paddle/phi/kernels/gpu/arange_kernel.cu       |  24 +++
 .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 184 ++++++++++++++++
 paddle/phi/kernels/gpu/flash_attn_kernel.cu   | 188 +++++++++++++++++
 python/env_dict.py.in                         |   1 +
 .../fluid/tests/unittests/CMakeLists.txt      |   4 +
 .../tests/unittests/test_flash_attention.py   | 196 ++++++++++++++++++
 .../paddle/nn/functional/flash_attention.py   | 123 +++++++++++
 python/setup.py.in                            |   4 +
 setup.py                                      |   5 +
 26 files changed, 1093 insertions(+)
 create mode 100644 cmake/external/flashattn.cmake
 create mode 100644 paddle/phi/backends/dynload/flashattn.cc
 create mode 100644 paddle/phi/backends/dynload/flashattn.h
 create mode 100644 paddle/phi/kernels/flash_attn_grad_kernel.h
 create mode 100644 paddle/phi/kernels/flash_attn_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/flash_attn_kernel.cu
 create mode 100644 python/paddle/fluid/tests/unittests/test_flash_attention.py
 create mode 100644 python/paddle/nn/functional/flash_attention.py

diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake
new file mode 100644
index 00000000000..83f0703ff8e
--- /dev/null
+++ b/cmake/external/flashattn.cmake
@@ -0,0 +1,108 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +include(ExternalProject) + +add_definitions(-DPADDLE_WITH_FLASHATTN) + +set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn) +set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn) +set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn) +set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git) +set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc) + +set(FLASHATTN_INCLUDE_DIR + "${FLASHATTN_INSTALL_DIR}/include" + CACHE PATH "flash-attn Directory" FORCE) +set(FLASHATTN_LIB_DIR + "${FLASHATTN_INSTALL_DIR}/lib" + CACHE PATH "flash-attn Library Directory" FORCE) + +if(WIN32) + set(FLASHATTN_LIBRARIES + "${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "flash-attn Library" FORCE) +else() + set(FLASHATTN_LIBRARIES + "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "flash-attn Library" FORCE) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" + OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" + OR WIN32) + set(USE_OMP OFF) +else() + set(USE_OMP ON) +endif() + +if(WIN32) + set(FLASHATTN_C_FLAGS $) + set(FLASHATTN_C_FLAGS_DEBUG + $) + set(FLASHATTN_C_FLAGS_RELEASE + $) + set(FLASHATTN_CXX_FLAGS $) + set(FLASHATTN_CXX_FLAGS_RELEASE + $) + set(FLASHATTN_CXX_FLAGS_DEBUG + $) +else() + set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS}) + set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG}) + set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE}) + set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) + set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) +endif() + +ExternalProject_Add( + extern_flashattn + ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE} + GIT_REPOSITORY ${FLASHATTN_REPOSITORY} + GIT_TAG ${FLASHATTN_TAG} + PREFIX ${FLASHATTN_PREFIX_DIR} + SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR} + UPDATE_COMMAND "" + PATCH_COMMAND "" + #BUILD_ALWAYS 1 + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS} + -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG} + -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR} + -DWITH_GPU=${WITH_GPU} + -DWITH_ROCM=${WITH_ROCM} + -DWITH_OMP=${USE_OMP} + -DBUILD_SHARED=ON + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR} + BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES}) + +message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}") +get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY) +include_directories(${FLASHATTN_INCLUDE_DIR}) + +add_library(flashattn INTERFACE) +#set_property(TARGET flashattn PROPERTY IMPORTED_LOCATION ${FLASHATTN_LIBRARIES}) +add_dependencies(flashattn extern_flashattn) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index a42aa48f96a..0946efcb726 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -531,6 +531,9 @@ if(WITH_GPU include(external/cutlass) # download, build, install cusparselt list(APPEND third_party_deps extern_cutlass) set(WITH_CUTLASS ON) + include(external/flashattn) + list(APPEND third_party_deps extern_flashattn) + 
set(WITH_FLASHATTN ON) endif() endif() diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 8492da75eb2..dd964e82ad8 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -508,6 +508,16 @@ func : fill_diagonal_tensor_grad inplace : (out_grad -> x_grad) +- backward_op : flash_attn_grad + forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset) + args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false) + output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) + infer_meta : + func : FlashAttnGradInferMeta + param : [q, k, v] + kernel : + func : flash_attn_grad + - backward_op : flip_grad forward : flip (Tensor x, int[] axis) -> Tensor(out) args : (Tensor out_grad, int[] axis) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 02edf19a75d..30e5f72e0c2 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -482,6 +482,16 @@ inplace : (x -> out) backward : fill_diagonal_tensor_grad +- op : flash_attn + args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) + output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset) + infer_meta : + func : FlashAttnInferMeta + param : [q, k, v] + kernel : + func : flash_attn + backward : flash_attn_grad + - op : flip args : (Tensor x, int[] axis) output : Tensor (out) diff --git a/paddle/phi/backends/dynload/CMakeLists.txt b/paddle/phi/backends/dynload/CMakeLists.txt index 73197846806..85826fe1cf7 100644 --- a/paddle/phi/backends/dynload/CMakeLists.txt +++ b/paddle/phi/backends/dynload/CMakeLists.txt @@ -92,6 +92,13 @@ if(WITH_MKLML) DEPS phi_dynamic_loader mklml) endif() +if(WITH_FLASHATTN) + cc_library( + phi_dynload_flashattn + SRCS flashattn.cc + DEPS phi_dynamic_loader flashattn) +endif() + cc_library( phi_dynload_lapack SRCS lapack.cc diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index 82ea94ea683..c7869e7eea8 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -484,6 +484,20 @@ void* GetWarpRNNTDsoHandle() { #endif } +void* GetFlashAttnDsoHandle() { + std::string flashattn_dir = ""; + if (!s_py_site_pkg_path.path.empty()) { + flashattn_dir = s_py_site_pkg_path.path; + } +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.dylib"); +#elif defined(_WIN32) + return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn.dll"); +#else + return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.so"); +#endif +} + void* GetNCCLDsoHandle() { #ifdef PADDLE_WITH_HIP std::string warning_msg( diff --git a/paddle/phi/backends/dynload/dynamic_loader.h b/paddle/phi/backends/dynload/dynamic_loader.h index 96a484b7c09..c8dec39fa83 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.h +++ b/paddle/phi/backends/dynload/dynamic_loader.h @@ -36,6 +36,7 @@ void* GetNVRTCDsoHandle(); void* GetCUDADsoHandle(); void* GetWarpCTCDsoHandle(); void* GetWarpRNNTDsoHandle(); +void* GetFlashAttnDsoHandle(); void* GetNCCLDsoHandle(); void* GetHCCLDsoHandle(); void* GetTensorRtDsoHandle(); diff --git a/paddle/phi/backends/dynload/flashattn.cc 
b/paddle/phi/backends/dynload/flashattn.cc new file mode 100644 index 00000000000..83ff0601dab --- /dev/null +++ b/paddle/phi/backends/dynload/flashattn.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/dynload/flashattn.h" + +namespace phi { +namespace dynload { + +std::once_flag flashattn_dso_flag; +void* flashattn_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +FLASHATTN_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/backends/dynload/flashattn.h b/paddle/phi/backends/dynload/flashattn.h new file mode 100644 index 00000000000..ec443fd9f8e --- /dev/null +++ b/paddle/phi/backends/dynload/flashattn.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include // NOLINT + +#include "flashattn/include/flash_attn.h" +#include "paddle/phi/backends/dynload/dynamic_loader.h" +#include "paddle/phi/backends/dynload/port.h" + +namespace phi { +namespace dynload { + +extern std::once_flag flashattn_dso_flag; +extern void* flashattn_dso_handle; + +#define DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + using flashattnFunc = decltype(&::__name); \ + std::call_once(flashattn_dso_flag, []() { \ + flashattn_dso_handle = phi::dynload::GetFlashAttnDsoHandle(); \ + }); \ + static void* p_##__name = dlsym(flashattn_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \ + DYNAMIC_LOAD_FLASHATTN_WRAP(__name) + +#define FLASHATTN_ROUTINE_EACH(__macro) \ + __macro(flash_attn_fwd); \ + __macro(flash_attn_bwd); \ + __macro(flash_attn_error); + +FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP); + +#undef DYNAMIC_LOAD_FLASHATTN_WRAP + +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 82183107da8..a27dfe29110 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -198,6 +198,23 @@ void CropGradInferMeta(const MetaTensor& out_grad, } } +void FlashAttnGradInferMeta(const MetaTensor& q, + const MetaTensor& k, + const MetaTensor& v, + MetaTensor* dq, + MetaTensor* dk, + MetaTensor* dv) { + if (dq) { + dq->share_meta(q); + } + if (dk && k) { + dk->share_meta(k); + } + if (dv && v) { + dv->share_meta(v); + } +} + void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, const MetaTensor& softmax, const MetaTensor& loss_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 01cdc8023a1..2f72aeec086 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -168,6 +168,13 @@ void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad, int dim2, MetaTensor* x_grad); +void FlashAttnGradInferMeta(const MetaTensor& q, + const MetaTensor& k, + const MetaTensor& v, + MetaTensor* dq, + MetaTensor* dk, + MetaTensor* dv); + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index d790d226b2d..9f787d07753 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -255,6 +255,18 @@ void BoxCoderInferMeta(const MetaTensor& prior_box, output_box->set_dtype(target_box.dtype()); } +void FlashAttnInferMeta(const MetaTensor& q, + const MetaTensor& k, + const MetaTensor& v, + MetaTensor* out, + MetaTensor* softmax_lse, + MetaTensor* softmax, + MetaTensor* seed_offset) { + out->set_dims(q.dims()); + out->set_dtype(q.dtype()); + out->set_layout(q.layout()); +} + void ArangeInferMeta(const MetaTensor& start, const MetaTensor& end, const MetaTensor& step, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 7f24f297009..338579930ae 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -63,6 +63,14 @@ void BoxCoderInferMeta(const MetaTensor& prior_box, MetaTensor* output_box, MetaConfig config = MetaConfig()); +void FlashAttnInferMeta(const MetaTensor& q, + const MetaTensor& k, + const MetaTensor& v, + MetaTensor* out, + MetaTensor* softmax_lse, + MetaTensor* softmax, + MetaTensor* seed_offset); + void InstanceNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index e1842510421..78d86d7ebf0 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -79,6 +79,10 @@ set(COMMON_KERNEL_DEPS utf8proc gather_scatter_functor) +if(WITH_FLASHATTN) + 
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_dynload_flashattn) +endif() + set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group) if(WITH_NCCL OR WITH_RCCL) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl) diff --git a/paddle/phi/kernels/arange_kernel.h b/paddle/phi/kernels/arange_kernel.h index be60824ac2b..6c879e27d79 100644 --- a/paddle/phi/kernels/arange_kernel.h +++ b/paddle/phi/kernels/arange_kernel.h @@ -25,4 +25,11 @@ void ArangeKernel(const Context& dev_ctx, const DenseTensor& step, DenseTensor* out); +template +void ArangeNullaryKernel(const Context& dev_ctx, + const T start, + const T end, + const T step, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/flash_attn_grad_kernel.h b/paddle/phi/kernels/flash_attn_grad_kernel.h new file mode 100644 index 00000000000..92ec093b27a --- /dev/null +++ b/paddle/phi/kernels/flash_attn_grad_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void FlashAttnGradKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const DenseTensor& dout, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv); + +} // namespace phi diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h new file mode 100644 index 00000000000..6a633d13b24 --- /dev/null +++ b/paddle/phi/kernels/flash_attn_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void FlashAttnKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + float dropout, + bool causal, + bool return_softmax, + DenseTensor* out, + DenseTensor* softmax_lse, + DenseTensor* softmax, + DenseTensor* seed_offset); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index 4fafda857dc..cb8d30186ff 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -52,6 +52,30 @@ void ArangeKernel(const Context& dev_ctx, Range<<>>(start_value, step_value, size, out_data); } +template +void ArangeNullaryKernel(const Context& dev_ctx, + const T start_value, + const T end_value, + const T step_value, + DenseTensor* out) { + int64_t size = 0; + phi::funcs::GetSize(start_value, end_value, step_value, &size); + out->Resize(phi::make_ddim({size})); + T* out_data = dev_ctx.template Alloc(out); + + auto stream = dev_ctx.stream(); + int64_t block = std::min(size, static_cast(256)); + if (block == 0) { + return; + } + int64_t grid = (size + block - 1) / block; + Range<<>>(start_value, step_value, size, out_data); +} + +template decltype(ArangeNullaryKernel) + ArangeNullaryKernel; +template decltype(ArangeNullaryKernel) + ArangeNullaryKernel; } // namespace phi PD_REGISTER_KERNEL( diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu new file mode 100644 index 00000000000..127d51562e5 --- /dev/null +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -0,0 +1,184 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/flash_attn_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/arange_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" + +#ifdef PADDLE_WITH_FLASHATTN +#include "paddle/phi/backends/dynload/flashattn.h" +#endif + +namespace phi { + +template +void FlashAttnGradKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const DenseTensor& dout, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { +#ifdef PADDLE_WITH_FLASHATTN + ctx.template Alloc(dq); + ctx.template Alloc(dk); + ctx.template Alloc(dv); + + cudaStream_t stream = ctx.stream(); + bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? 
true : false; + + // q,k,v [batch_size, seq_len, num_heads, head_dim] + + auto dims = q.dims(); + int64_t batch_size = dims[0]; + int64_t seq_len_q = dims[1]; + int64_t num_heads = dims[2]; + int64_t head_size = dims[3]; + + int64_t seq_len_k = k.dims()[1]; + + int64_t total_q = batch_size * seq_len_q; + int64_t total_k = batch_size * seq_len_k; + + DenseTensor q_t_s = + Reshape(ctx, q, {total_q, num_heads, head_size}); + DenseTensor k_t_s = + Reshape(ctx, k, {total_k, num_heads, head_size}); + DenseTensor v_t_s = + Reshape(ctx, v, {total_k, num_heads, head_size}); + + // q,k,v [total_*, num_heads, head_dim] + + DenseTensor cu_seqlens_q; + DenseTensor cu_seqlens_k; + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q); + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k); + + float scale = 1.0f / std::sqrt(head_size); + int num_splits = 0; // 0 for an internal heuristic, which is optimal + bool zero_tensors = false; + + std::vector seed_offset_vec; + phi::TensorToVector(seed_offset, ctx, &seed_offset_vec); + uint64_t seed = seed_offset_vec[0]; + uint64_t offset = seed_offset_vec[1]; + + DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seq_len_q}); + + uint64_t workspace_size; + + // calculate workspace size before execution + bool succ = phi::dynload::flash_attn_bwd( + q_t_s.data(), + k_t_s.data(), + v_t_s.data(), + dq->data(), + dk->data(), + dv->data(), + nullptr, // for calculation workspace size + dout.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + seq_len_q, + seq_len_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + const_cast(softmax_lse.data()), + dsoftmax.data(), + nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_bwd( + q_t_s.data(), + k_t_s.data(), + v_t_s.data(), + dq->data(), + dk->data(), + dv->data(), + out.data(), + dout.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + seq_len_q, + seq_len_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + const_cast(softmax_lse.data()), + dsoftmax.data(), + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(flash_attn_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(5).SetBackend(phi::Backend::CPU); // seed_offset +} diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu new file mode 100644 index 00000000000..19079a3573f --- /dev/null +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -0,0 +1,188 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/flash_attn_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#include "paddle/phi/kernels/arange_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" + +#ifdef PADDLE_WITH_FLASHATTN +#include "paddle/phi/backends/dynload/flashattn.h" +#endif + +namespace phi { + +template +void FlashAttnKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + float dropout, + bool causal, + bool return_softmax, + DenseTensor* out, + DenseTensor* softmax_lse, + DenseTensor* softmax, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + ctx.template Alloc(out); + + cudaStream_t stream = ctx.stream(); + bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false; + + // q,k,v [batch_size, seq_len, num_heads, head_dim] + + auto dims = q.dims(); + int64_t batch_size = dims[0]; + int64_t seq_len_q = dims[1]; + int64_t num_heads = dims[2]; + int64_t head_size = dims[3]; + + int64_t seq_len_k = k.dims()[1]; + + int64_t total_q = batch_size * seq_len_q; + int64_t total_k = batch_size * seq_len_k; + + DenseTensor q_t_s = + Reshape(ctx, q, {total_q, num_heads, head_size}); + DenseTensor k_t_s = + Reshape(ctx, k, {total_k, num_heads, head_size}); + DenseTensor v_t_s = + Reshape(ctx, v, {total_k, num_heads, head_size}); + + // q,k,v [total_*, num_heads, head_dim] + + DenseTensor cu_seqlens_q; + DenseTensor cu_seqlens_k; + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q); + ArangeNullaryKernel( + ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k); + + float scale = 1.0f / std::sqrt(head_size); + int num_splits = 0; // 0 for an internal heuristic, which is optimal + bool zero_tensors = false; + + auto gen = ctx.GetGenerator(); + uint64_t inc = batch_size * num_heads * 32; + auto seed_offset_pair = gen->IncrementOffset(inc); + uint64_t seed = seed_offset_pair.first; + uint64_t offset = seed_offset_pair.second; + + std::vector seed_offset_vec{int64_t(seed), int64_t(offset)}; + phi::TensorFromVector(seed_offset_vec, ctx, seed_offset); + + softmax_lse->Resize({batch_size, num_heads, seq_len_q}); + ctx.template Alloc(softmax_lse); + + if (return_softmax) { + // may allocate more space than *seq_len_k* + int64_t blocksize_c = head_size > 64 ? 128 : 256; + int64_t max_len_k_ = + ((seq_len_k + blocksize_c - 1) / blocksize_c) * blocksize_c; + int64_t max_len_k = + seq_len_k <= 128 ? 128 : (seq_len_k <= 256 ? 
256 : max_len_k_); + softmax->Resize({batch_size, num_heads, seq_len_q, max_len_k}); + ctx.template Alloc(softmax); + } + + uint64_t workspace_size; + + // calculate workspace size before execution + bool succ = + phi::dynload::flash_attn_fwd(q_t_s.data(), + k_t_s.data(), + v_t_s.data(), + nullptr, // for calculation workspace size + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + seq_len_q, + seq_len_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax ? softmax->data() : nullptr, + nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_fwd( + q_t_s.data(), + k_t_s.data(), + v_t_s.data(), + out->data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + seq_len_q, + seq_len_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax ? softmax->data() : nullptr, + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(flash_attn, + GPU, + ALL_LAYOUT, + phi::FlashAttnKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/python/env_dict.py.in b/python/env_dict.py.in index 1e3a2b18114..a3140afd868 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -12,6 +12,7 @@ env_dict={ 'FLUID_CORE_NAME':'@FLUID_CORE_NAME@', 'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@', 'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@', + 'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@', 'LAPACK_LIB':'@LAPACK_LIB@', 'GFORTRAN_LIB':'@GFORTRAN_LIB@', 'GNU_RT_LIB_1':'@GNU_RT_LIB_1@', diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2d4db9df69e..35c8fecaa74 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -486,6 +486,10 @@ if(NOT WITH_GPU list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) endif() +if(NOT WITH_FLASHATTN) + list(REMOVE_ITEM TEST_OPS test_flash_attention) +endif() + # Some ops need to check results when gc is enabled # Currently, only ops that register NoNeedBufferVarsInference need to do this test set(TEST_OPS_WITH_GC diff --git a/python/paddle/fluid/tests/unittests/test_flash_attention.py b/python/paddle/fluid/tests/unittests/test_flash_attention.py new file mode 100644 index 00000000000..1b3593c74a4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_flash_attention.py @@ -0,0 +1,196 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import unittest + +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.nn.functional.flash_attention import flash_attention + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +def attention_naive(q, k, v, causal=False): + qt = paddle.transpose(q, [0, 2, 1, 3]) + kt = paddle.transpose(k, [0, 2, 1, 3]) + vt = paddle.transpose(v, [0, 2, 1, 3]) + scale = 1.0 / np.sqrt(q.shape[-1]) + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) + s = paddle.scale(s, scale) + p = ( + paddle.incubate.softmax_mask_fuse_upper_triangle(s) + if causal + else F.softmax(s) + ) + o = paddle.matmul(p, vt) + return paddle.transpose(o, [0, 2, 1, 3]) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11030, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3", +) +class TestFlashAttentionAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 128, 8, 16) + self.blocksize = 2 + self.dtype = 'float16' + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + + def test_all(self): + print( + f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}" + ) + # test dynamic + paddle.disable_static() + + query = np.random.random(self.shape) + key = np.random.random(self.shape) + value = np.random.random(self.shape) + + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + q_ = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k_ = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v_ = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + out, _ = flash_attention( + q, k, v, self.dropout, self.causal, self.return_softmax + ) + out_ = attention_naive(q_, k_, v_, self.causal) + + out.backward() + out_.backward() + + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) + + self.assertEqual(q.grad.shape, q.shape) + self.assertEqual(q_.grad.shape, q.shape) + + np.testing.assert_allclose( + q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 + ) + + # test static + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + qs = paddle.static.data( + name="q", shape=self.shape, dtype=self.dtype + ) + ks = paddle.static.data( + name="k", shape=self.shape, dtype=self.dtype + ) + vs = paddle.static.data( + name="v", shape=self.shape, dtype=self.dtype + ) + + outs, softmax = flash_attention( + qs, ks, vs, self.dropout, self.causal, self.return_softmax + ) + + exe = fluid.Executor(self.place) + fetches_result = exe.run( + feed={ + "q": query.astype('float16'), + "k": key.astype('float16'), + "v": value.astype('float16'), + }, + fetch_list=[outs], + ) + + np.testing.assert_allclose( + fetches_result[0], out_, rtol=5e-03, atol=1e-03 + ) + + +class 
TestFlashAttentionAPITest1(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (2, 128, 8, 16)
+        self.blocksize = 2
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+
+
+class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (2, 256, 8, 16)
+        self.blocksize = 2
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = True
+
+
+class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (2, 512, 8, 16)
+        self.blocksize = 2
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = True
+        self.return_softmax = False
+
+
+class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.blocksize = 2
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
new file mode 100644
index 00000000000..0bda34c2436
--- /dev/null
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import _C_ops, in_dynamic_mode
+from paddle.fluid.layer_helper import LayerHelper
+
+
+def flash_attention(
+    query,
+    key,
+    value,
+    dropout=0.0,
+    causal=False,
+    return_softmax=False,
+    name=None,
+):
+    r"""
+    The equation is:
+
+    .. math::
+
+        result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
+
+    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+
+    Warning:
+        This API only supports inputs with dtype float16 and bfloat16.
+
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+                       4-D tensor with shape:
+                       [batch_size, seq_len, num_heads, head_dim].
+                       The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+                     4-D tensor with shape:
+                     [batch_size, seq_len, num_heads, head_dim].
+                     The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+                       4-D tensor with shape:
+                       [batch_size, seq_len, num_heads, head_dim].
+                       The dtype can be float16 or bfloat16.
+        dropout(float): The dropout ratio.
+        causal(bool): Whether to enable causal mode.
+        return_softmax(bool): Whether to return the softmax tensor.
+        name(str, optional): The default value is None. Normally there is no need for the user
+                             to set this property. For more information, please refer to
+                             :ref:`api_guide_Name`.
+
+    Returns:
+        out(Tensor): The attention tensor.
+ 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. + The dtype can be float16 or bfloat16. + softmax(Tensor): The softmax tensor. None if return_softmax is False. + + Examples: + .. code-block:: python + + # required: skiptest + import paddle + + q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16) + + output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False) + print(output) + """ + if in_dynamic_mode(): + ( + result_attention, + result_softmax_lse, + result_softmax, + seed_offset, + ) = _C_ops.flash_attn( + query, + key, + value, + dropout, + causal, + return_softmax, + ) + return result_attention, result_softmax + + helper = LayerHelper('flash_attn', **locals()) + dtype = helper.input_dtype(input_param_name='q') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'q': query, + 'k': key, + 'v': value, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn', + inputs=inputs, + outputs=outputs, + attrs={ + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + }, + ) + return out, softmax diff --git a/python/setup.py.in b/python/setup.py.in index 198599cb019..d75e8b08755 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -545,6 +545,10 @@ if not sys.platform.startswith("linux"): package_data['paddle.libs']+=[os.path.basename('${GNU_RT_LIB_2}')] shutil.copy('${GNU_RT_LIB_2}', libs_path) +if len('${FLASHATTN_LIBRARIES}') > 1: + package_data['paddle.libs']+=[os.path.basename('${FLASHATTN_LIBRARIES}')] + shutil.copy('${FLASHATTN_LIBRARIES}', libs_path) + if '${WITH_MKL}' == 'ON': shutil.copy('${MKLML_SHARED_LIB}', libs_path) shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path) diff --git a/setup.py b/setup.py index 429db65282e..c53bd5f06c6 100644 --- a/setup.py +++ b/setup.py @@ -918,6 +918,11 @@ def get_package_data_and_package_dir(): shutil.copy(env_dict.get("OPENBLAS_LIB") + '.0', libs_path) package_data['paddle.libs'] += ['libopenblas.so.0'] + if len(env_dict.get("FLASHATTN_LIBRARIES", "")) > 1: + package_data['paddle.libs'] += [ + os.path.basename(env_dict.get("FLASHATTN_LIBRARIES")) + ] + shutil.copy(env_dict.get("FLASHATTN_LIBRARIES"), libs_path) if env_dict.get("WITH_LITE") == 'ON': shutil.copy(env_dict.get("LITE_SHARED_LIB"), libs_path) package_data['paddle.libs'] += [ -- GitLab
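Below is a minimal usage sketch of the `paddle.nn.functional.flash_attention` API introduced by this patch. It is illustrative only: it assumes a Paddle build with WITH_FLASHATTN=ON and a CUDA device that supports float16, and the tensor shape is borrowed from the unit test above rather than being a requirement of the API.

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# q, k, v: [batch_size, seq_len, num_heads, head_dim], float16 or bfloat16
q = paddle.rand((2, 128, 8, 16), dtype=paddle.float16)
k = paddle.rand((2, 128, 8, 16), dtype=paddle.float16)
v = paddle.rand((2, 128, 8, 16), dtype=paddle.float16)

# positional arguments: dropout, causal, return_softmax
out, softmax = flash_attention(q, k, v, 0.0, False, False)
print(out.shape)  # [2, 128, 8, 16]; softmax is only populated when return_softmax=True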