flashattn.cmake 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

# Adds -DPADDLE_WITH_FLASHATTN to the compile line of every target in the
# current directory and in sub-directories added afterwards (directory scope).
add_definitions(-DPADDLE_WITH_FLASHATTN)

# Upstream source: PaddlePaddle's flash-attention fork, pinned to one commit.
set(FLASHATTN_REPOSITORY "${GIT_URL}/PaddlePaddle/flash-attention.git")
set(FLASHATTN_TAG "f0edf243a813a65d05c75fcb331b2a95faf96bbc")

# ExternalProject working prefix, install prefix, and the sub-directory of the
# checkout that holds the CMake project to configure.
set(FLASHATTN_PREFIX_DIR "${THIRD_PARTY_PATH}/flashattn")
set(FLASHATTN_INSTALL_DIR "${THIRD_PARTY_PATH}/install/flashattn")
set(FLASHATTN_SOURCE_SUBDIR "csrc/flash_attn")

# Header and library directories under the install prefix. FORCE keeps the
# cached values in sync with FLASHATTN_INSTALL_DIR on every configure.
set(FLASHATTN_INCLUDE_DIR "${FLASHATTN_INSTALL_DIR}/include"
    CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR "${FLASHATTN_INSTALL_DIR}/lib"
    CACHE PATH "flash-attn Library Directory" FORCE)

# Expected path of the flash-attn shared library inside the install prefix
# (also listed as the BUILD_BYPRODUCTS of extern_flashattn). Non-Windows
# platforms use lib/libflashattn<suffix>; Windows uses bin/flashattn<suffix>
# with no "lib" prefix. FORCE keeps the cache entry in sync on reconfigure.
if(NOT WIN32)
  set(FLASHATTN_LIBRARIES
      "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn Library" FORCE)
else()
  set(FLASHATTN_LIBRARIES
      "${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn Library" FORCE)
endif()

# OpenMP switch for the flash-attn sub-build (forwarded as -DWITH_OMP).
# On by default; turned off on Windows and when the C++ compiler is
# Clang or AppleClang.
set(USE_OMP ON)
if(WIN32
   OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
   OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
  set(USE_OMP OFF)
endif()

# Compiler flags forwarded to the flash-attn sub-build through the
# -DCMAKE_<LANG>_FLAGS* entries of extern_flashattn's CMAKE_ARGS. Each
# FLASHATTN_<X> mirrors CMAKE_<X>. On Windows the MSVC option /Zc:inline is
# removed first (NOTE(review): presumably the sub-build's toolchain rejects it;
# confirm).
#
# CMAKE_<LANG>_FLAGS* are space-separated strings, not ';'-separated CMake
# lists. A $<FILTER:...,EXCLUDE,/Zc:inline> genex therefore sees the whole
# string as a single item, and discards *all* flags whenever /Zc:inline occurs
# anywhere in it. string(REPLACE) strips just that one token.
foreach(flashattn_flag_var C_FLAGS C_FLAGS_DEBUG C_FLAGS_RELEASE CXX_FLAGS
                           CXX_FLAGS_RELEASE CXX_FLAGS_DEBUG)
  set(FLASHATTN_${flashattn_flag_var} "${CMAKE_${flashattn_flag_var}}")
  if(WIN32)
    string(REPLACE "/Zc:inline" "" FLASHATTN_${flashattn_flag_var}
                   "${FLASHATTN_${flashattn_flag_var}}")
  endif()
endforeach()

# Build flash-attention as a separate CMake project:
#  - clone FLASHATTN_REPOSITORY at FLASHATTN_TAG under FLASHATTN_PREFIX_DIR and
#    configure the CMake project found in its FLASHATTN_SOURCE_SUBDIR;
#  - the update and patch steps are disabled (empty commands);
#  - forward this build's compilers and FLASHATTN_* flags, the GPU / ROCm /
#    OpenMP switches, and request a shared, position-independent build that is
#    installed into FLASHATTN_INSTALL_DIR;
#  - declare FLASHATTN_LIBRARIES as a byproduct so generators such as Ninja
#    know which step produces that file.
ExternalProject_Add(
  extern_flashattn
  # Option lists defined elsewhere (presumably logging and shallow-clone
  # settings - NOTE(review): confirm at their definition).
  ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
  GIT_REPOSITORY ${FLASHATTN_REPOSITORY}
  GIT_TAG ${FLASHATTN_TAG}
  PREFIX ${FLASHATTN_PREFIX_DIR}
  SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
  UPDATE_COMMAND ""
  PATCH_COMMAND ""
  # Uncomment to re-run the build step on every build (debugging aid).
  #BUILD_ALWAYS    1
  CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
             -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
             -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
             -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
             -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
             -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
             -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
             -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
             -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
             # Feature switches understood by the flash-attn CMake project.
             -DWITH_GPU=${WITH_GPU}
             -DWITH_ROCM=${WITH_ROCM}
             -DWITH_OMP=${USE_OMP}
             -DBUILD_SHARED=ON
             -DCMAKE_POSITION_INDEPENDENT_CODE=ON
             -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
             # Extra arguments defined elsewhere; contents not visible here.
             ${EXTERNAL_OPTIONAL_ARGS}
  # Written to an initial-cache script loaded with -C rather than passed as -D.
  # NOTE(review): these three entries repeat values already given in
  # CMAKE_ARGS above.
  CMAKE_CACHE_ARGS
    -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
    -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
    -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
  BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})

message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
# Directory that holds the shared library (bin/ on Windows, lib/ elsewhere).
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
# Directory-scoped include path, kept for targets that use the flash-attn
# headers without linking the `flashattn` target.
include_directories(${FLASHATTN_INCLUDE_DIR})

# `flashattn` is an INTERFACE target. It does not link the flash-attn library
# itself (IMPORTED_LOCATION is only honored on IMPORTED targets, so no location
# is set here). NOTE(review): presumably the library is loaded dynamically
# elsewhere - confirm.
# Linking `flashattn` gives consumers:
#  - the header directory as a usage requirement (build-tree only, so the
#    absolute path never leaks into an install/export set);
#  - a build-order dependency on extern_flashattn, because dependencies of an
#    interface library are followed transitively by its consumers.
add_library(flashattn INTERFACE)
target_include_directories(
  flashattn INTERFACE "$<BUILD_INTERFACE:${FLASHATTN_INCLUDE_DIR}>")
add_dependencies(flashattn extern_flashattn)