...
 
Commits (11)
* Fix test_resnet and test_resnet_v2 ut (#55723) (#55863)
  2023-08-02T11:16:22+08:00, WangZhen <23097963+0x45f@users.noreply.github.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/42ab2c34b3dd76b54b547613e12393413350c285
  Notes: Fix test_resnet and test_resnet_v2 ut; Remove ut
* [BugFix]Fix bug in vpp+ sharding/dp overlap (#55890)
  2023-08-02T17:05:09+08:00, ShenLiang <1422485404@qq.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/7b7ec08eb285c91880864603a46b05e79f51841e
  Notes: fix bug (x5)
* add fleet test tools. (#55701)
  2023-08-02T20:15:59+08:00, wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/559d677a1551b40f54a9f2c8148043a446b74ac6
* cherry-pick fused_rope from develop (#55931)
  2023-08-07T11:35:17+08:00, niuliling123 <51102941+niuliling123@users.noreply.github.com>
  Co-authored-by: sneaxiy <sneaxiy@126.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/8d3a9882ce863d230a817bfdb5eaaab2fb883eb4
  Notes: Add fused_rope forward op (#54351); style; more; update ctest; Update legacy_backward.yaml;
  Update legacy_ops.yaml (twice); update (three times); update for move; Update the rope op according
  to the comments (#54985); Update multiary.cc; Update __init__.py; for int64_t and assert; more;
  remove useless assert first
* [cherry-pick] Integration flash attention 2 (#56015)
  2023-08-07T19:27:49+08:00, umiswing <umiswing@foxmail.com>
  Co-authored-by: Chitsing KUI <kuizhiqing@msn.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/cc9a7688c3e5591df19c0be53e9d04d25b5cd1e0
  Notes: [FlashAttn] add flash randomness control (#52902): add flash randomness control; fix VLOG
  undefied. [WIP] Integration flash attention 2 (#55758): Work for fa-2 padded fwd, code to be cleaned;
  Work for fa2 unpadded fwd; Work for padded-bwd, dk get small diff on np.random.seed(0); Anyway I pass
  paddle's utest, except return softmax without dropout; Clean code; Modify interface; Clean code and
  add some check; Easy compile for dev; Fix ci; Fix ci-build; Add std c++17 option again; Limit max job
  when compiling fa2; Remove const_cast; Add fwd params; Add bwd params; Add enforce; Use v2.0.4; Pass
  RNG state to fa2 capi; Fix review; Add assert; Skip compile for sm less than 80
* Make flash attn v1 available (#56040)
  2023-08-08T13:04:18+08:00, sneaxiy <32832641+sneaxiy@users.noreply.github.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/5d991c6fc95c94169d24f6ec5d281c792624b6f9
  Notes: make flash attn v1 available; add deps error; refine cmake dependencies; fix cmake error
* fix bug for fused_linear_grad_add and main_grad (#56030) (#56071)
  2023-08-09T07:06:36+08:00, Yuang Liu <liuyuang@baidu.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/6131aebce4e40234ad681e6d684994c9ac70b79b
* fix codestyle (#56066)
  2023-08-09T09:55:15+08:00, ShenLiang <1422485404@qq.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/caa0f3774cdd93b766f048bfcda16ddde6b06b67
* Add assert for static and other plateform (#56044)
  2023-08-09T13:46:55+08:00, niuliling123 <51102941+niuliling123@users.noreply.github.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/9b317b2d484f8420feb05deb225812189b5f62b7
* fused linear grad add bug fix and perf optim (#56094)
  2023-08-09T14:01:58+08:00, Yuang Liu <liuyuang@baidu.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/77da91068dc0c0cf90ac9f41bf9d8b8ab3348261
  Notes: skip CopyOrAdd when tmp grad is None (#55679); Optim fused linear grad add (#55927)
* dp and sharding coexist (#56096)
  2023-08-10T18:50:27+08:00, liuzhenhai93 <liuzhenhai93@outlook.com>
  https://gitcode.net/paddlepaddle/Paddle/-/commit/75cc7057454a557a0136d47e6391c9c8d17b670a
  Notes: dp and sharding coexist; dp
......@@ -17,10 +17,10 @@ include(ExternalProject)
add_definitions(-DPADDLE_WITH_FLASHATTN)
set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_TAG b5bdb79d5e1f2f88b1ef62e86899a14f82fa079a)
set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
......@@ -62,7 +62,7 @@ else()
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
......@@ -93,6 +93,8 @@ ExternalProject_Add(
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include(ExternalProject)
add_definitions(-DPADDLE_WITH_FLASHATTN)
set(FLASHATTN_V1_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn_v1)
set(FLASHATTN_V1_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_V1_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn_v1)
set(FLASHATTN_V1_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_V1_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_V1_INCLUDE_DIR
"${FLASHATTN_V1_INSTALL_DIR}/include"
CACHE PATH "flash-attn v1 Directory" FORCE)
set(FLASHATTN_V1_LIB_DIR
"${FLASHATTN_V1_INSTALL_DIR}/lib"
CACHE PATH "flash-attn v1 Library Directory" FORCE)
if(WIN32)
set(FLASHATTN_V1_OLD_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
set(FLASHATTN_V1_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
else()
set(FLASHATTN_V1_OLD_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
set(FLASHATTN_V1_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()
if(WIN32)
set(FLASHATTN_V1_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FLASHATTN_V1_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_V1_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_V1_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_V1_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_V1_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_V1_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
ExternalProject_Add(
extern_flashattn_v1
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${FLASHATTN_V1_REPOSITORY}
GIT_TAG ${FLASHATTN_V1_TAG}
PREFIX ${FLASHATTN_V1_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_V1_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_V1_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_V1_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_V1_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_V1_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_V1_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_V1_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_V1_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_V1_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_V1_LIBRARIES})
add_custom_target(
extern_flashattn_v1_move_lib
COMMAND ${CMAKE_COMMAND} -E copy ${FLASHATTN_V1_OLD_LIBRARIES}
${FLASHATTN_V1_LIBRARIES})
add_dependencies(extern_flashattn_v1_move_lib extern_flashattn_v1)
message(STATUS "flash-attn v1 library: ${FLASHATTN_V1_LIBRARIES}")
get_filename_component(FLASHATTN_V1_LIBRARY_PATH ${FLASHATTN_V1_LIBRARIES}
DIRECTORY)
include_directories(${FLASHATTN_V1_INCLUDE_DIR})
add_library(flashattn_v1 INTERFACE)
#set_property(TARGET flashattn_v1 PROPERTY IMPORTED_LOCATION ${FLASHATTN_V1_LIBRARIES})
add_dependencies(flashattn_v1 extern_flashattn_v1 extern_flashattn_v1_move_lib)
......@@ -512,10 +512,16 @@ if(WITH_GPU
list(APPEND third_party_deps extern_cutlass)
set(WITH_CUTLASS ON)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.2)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
include(external/flashattn_v1)
list(APPEND third_party_deps extern_flashattn extern_flashattn_v1)
set(WITH_FLASHATTN ON)
break()
endif()
endforeach()
endif()
endif()
......
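The third_party.cmake hunk above raises the CUDA floor to 11.4 and only pulls in external/flashattn and external/flashattn_v1 when at least one target architecture in NVCC_ARCH_BIN is SM 8.0 or newer. A small runtime-side sketch of the same gate, using Paddle's public device queries (the helper name is hypothetical, not part of this PR):

```python
import paddle

def flash_attention_supported():
    # Mirror the build-time gate: FlashAttention-2 kernels are compiled only for
    # CUDA >= 11.4 and SM >= 80, so require a CUDA build and an Ampere-or-newer GPU.
    if not paddle.is_compiled_with_cuda():
        return False
    major, _minor = paddle.device.cuda.get_device_capability()
    return major >= 8
```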
......@@ -124,7 +124,10 @@ GradNodeAccumulation::operator()(
if (!weak_grad_.expired() && !is_new_grad) {
auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
if (grad_out.defined() && grad_out.initialized()) {
CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
}
// else { do nothing since there is no valid value in grad out tensor }
is_fake_empty_ = false;
}
......
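The accumulation-node hunk above only calls CopyOrAddTensor when grad_out actually holds data, which is what lets the "skip CopyOrAdd when tmp grad is None" fix (#55679) coexist with main_grad. A minimal Python sketch of the guarded accumulation rule (names are illustrative, not Paddle's API):

```python
def accumulate_grad(stored_grad, grad_out):
    # Guard: an undefined/uninitialized grad_out contributes nothing,
    # matching the new `grad_out.defined() && grad_out.initialized()` check.
    if grad_out is None:
        return stored_grad
    # The first valid gradient is copied, later ones are added (CopyOrAddTensor).
    return grad_out if stored_grad is None else stored_grad + grad_out
```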
......@@ -617,7 +617,7 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......@@ -628,7 +628,7 @@
data_type: q
- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......@@ -638,6 +638,28 @@
func : flash_attn_unpadded_grad
data_type: q
- backward_op : flash_attn_v1_grad
forward : flash_attn_v1 (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnV1GradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_grad
data_type: q
- backward_op : flash_attn_v1_unpadded_grad
forward : flash_attn_v1_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnV1GradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_unpadded_grad
data_type: q
- backward_op : flatten_grad
forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
......
......@@ -15,3 +15,15 @@
func : fused_dropout_add_grad
data_type : out_grad
support_dygraph_mode : true
- backward_op : fused_rotary_position_embedding_grad
forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
args : (Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
optional : out_k_grad, out_v_grad, k_grad, v_grad
infer_meta :
func : FusedRopeGradInferMeta
kernel :
func : fused_rotary_position_embedding_grad
data_type : out_q_grad
support_dygraph_mode : true
......@@ -45,7 +45,7 @@
support_dygraph_mode : true
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true, bool has_bias = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
......@@ -65,6 +65,18 @@
data_type : x
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask
- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v)
output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
infer_meta :
func : FusedRopeInferMeta
optional : k,v, out_k, out_v
kernel :
func : fused_rotary_position_embedding
data_type : q
backward: fused_rotary_position_embedding_grad
support_dygraph_mode : true
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
......
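The fused_rotary_position_embedding entries register a fused RoPE forward/backward pair that takes q and optional k, v of shape [batch_size, seq_len, num_heads, head_dim]. A hedged usage sketch, assuming the dygraph binding generated from this yaml is exposed as paddle.incubate.nn.functional.fused_rotary_position_embedding (as in the cherry-picked #54351):

```python
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

# Shapes follow the infer-meta contract: [batch_size, seq_len, num_heads, head_dim],
# with head_dim even so that (even, odd) channel pairs can be rotated.
q = paddle.randn([2, 128, 8, 64], dtype="float16")
k = paddle.randn([2, 128, 8, 64], dtype="float16")
v = paddle.randn([2, 128, 8, 64], dtype="float16")

# k and v are optional; when given, they are rotated with the same sin/cos tables as q.
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
```

The fused_linear_param_grad_add op additionally gains a has_bias attribute; its semantics are sketched after the kernel changes below.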
......@@ -678,20 +678,21 @@
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
kernel :
func : flash_attn
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_grad
- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......@@ -701,6 +702,29 @@
intermediate : softmax_lse, seed_offset
backward : flash_attn_unpadded_grad
- op : flash_attn_v1
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnV1InferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1
data_type : q
backward : flash_attn_v1_grad
- op : flash_attn_v1_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnV1InferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_unpadded
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_v1_unpadded_grad
- op : flatten
args : (Tensor x, int start_axis = 1, int stop_axis = 1)
output : Tensor(out), Tensor(xshape)
......
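These legacy_ops.yaml changes add an optional fixed_seed_offset input and an rng_name attribute to flash_attn / flash_attn_unpadded (so dropout randomness can be controlled) and register the previous implementation as flash_attn_v1 / flash_attn_v1_unpadded. A hedged usage sketch of the public dygraph wrapper, assuming it keeps the flash_attention(query, key, value, dropout, causal, return_softmax) form in paddle.nn.functional.flash_attention and returns (out, softmax):

```python
import paddle
from paddle.nn.functional.flash_attention import flash_attention

# [batch_size, seq_len, num_heads, head_dim]; per the cmake gating above, the
# FlashAttention-2 path requires CUDA >= 11.4 and an SM 8.0+ GPU.
q = paddle.randn([2, 1024, 16, 128], dtype="bfloat16")
k = paddle.randn([2, 1024, 16, 128], dtype="bfloat16")
v = paddle.randn([2, 1024, 16, 128], dtype="bfloat16")

out, softmax = flash_attention(q, k, v, dropout=0.1, causal=True, return_softmax=False)
```

If the wrapper also exposes fixed_seed_offset / rng_name, passing them makes the dropout pattern reproducible across runs; that plumbing is exactly what the new optional input carries to the kernels.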
......@@ -97,8 +97,8 @@ endif()
if(WITH_FLASHATTN)
cc_library(
phi_dynload_flashattn
SRCS flashattn.cc
DEPS phi_dynamic_loader flashattn)
SRCS flashattn.cc flashattn_v1.cc
DEPS phi_dynamic_loader flashattn flashattn_v1)
endif()
cc_library(
......
......@@ -500,6 +500,20 @@ void* GetFlashAttnDsoHandle() {
#endif
}
void* GetFlashAttnV1DsoHandle() {
std::string flashattn_dir = "";
if (!s_py_site_pkg_path.path.empty()) {
flashattn_dir = s_py_site_pkg_path.path;
}
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn_v1.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn_v1.dll");
#else
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn_v1.so");
#endif
}
void* GetNCCLDsoHandle() {
#ifdef PADDLE_WITH_HIP
std::string warning_msg(
......
......@@ -37,6 +37,7 @@ void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetFlashAttnDsoHandle();
void* GetFlashAttnV1DsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle();
......
......@@ -43,9 +43,13 @@ extern void* flashattn_dso_handle;
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_varlen_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_varlen_bwd); \
__macro(flash_attn_fwd_with_bias_and_mask); \
__macro(flash_attn_bwd_with_bias_and_mask); \
__macro(flash_attn_error);
FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/dynload/flashattn_v1.h"
namespace phi {
namespace dynload {
std::once_flag flashattn_v1_dso_flag;
void* flashattn_v1_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name##__v1 __name##_v1
FLASHATTN_V1_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex> // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
namespace phi {
namespace dynload {
extern std::once_flag flashattn_v1_dso_flag;
extern void *flashattn_v1_dso_handle;
using flash_attn_fwd_v1_func_t = bool (*)(
const void * /*q*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*k*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*v*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*out*/, // total_q x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*cu_seqlens_q*/, // int32, batch_size+1, starting offset of
// each sequence
const void * /*cu_seqlens_k*/, // int32, batch_size+1, starting offset of
// each sequence
const int /*total_q*/,
const int /*total_k*/,
const int /*batch_size*/,
const int /*num_heads*/,
const int /*head_size*/,
const int /*max_seqlen_q_*/,
const int /*max_seqlen_k_*/,
const float /*p_dropout*/,
const float /*softmax_scale*/,
const bool /*zero_tensors*/,
const bool /*is_causal*/,
const bool /*is_bf16*/,
const int /*num_splits*/, // SMs per attention matrix, can be 1
void * /*softmax_lse_ptr*/, // softmax log_sum_exp
void * /*softmax_ptr*/,
void * /*workspace_ptr*/,
uint64_t * /*workspace_size*/,
cudaStream_t /*stream*/,
uint64_t /*seed*/,
uint64_t /*offset*/
);
using flash_attn_bwd_v1_func_t = bool (*)(
const void * /*q*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*k*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*v*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*dq*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
void * /*dk*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*dv*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*out*/, // total_q x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*dout*/, // total_q x num_heads, x head_size
const void * /*cu_seqlens_q*/, // int32, batch_size+1
const void * /*cu_seqlens_k*/, // int32, batch_size+1
const int /*total_q*/,
const int /*total_k*/,
const int /*batch_size*/,
const int /*num_heads*/,
const int /*head_size*/,
const int /*max_seqlen_q_*/,
const int /*max_seqlen_k_*/,
const float /*p_dropout*/,
const float /*softmax_scale*/,
const bool /*zero_tensors*/,
const bool /*is_causal*/,
const bool /*is_bf16*/,
const int /*num_splits*/,
void * /*softmax_lse_ptr*/,
void * /*dsoftmax_ptr*/,
void * /*workspace_ptr*/,
uint64_t * /*workspace_size*/,
cudaStream_t /*stream*/,
uint64_t /*seed*/,
uint64_t /*offset*/
);
using flash_attn_error_v1_func_t = const char *(*)();
#define DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
struct DynLoad__##__name##__v1 { \
template <typename... Args> \
auto operator()(Args... args) { \
using flashattnFunc = ::phi::dynload::__name##_v1_func_t; \
std::call_once(flashattn_v1_dso_flag, []() { \
flashattn_v1_dso_handle = phi::dynload::GetFlashAttnV1DsoHandle(); \
}); \
static void *p_##__name = dlsym(flashattn_v1_dso_handle, #__name); \
return reinterpret_cast<flashattnFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name##__v1 __name##_v1
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name)
#define FLASHATTN_V1_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_error);
FLASHATTN_V1_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP);
#undef DYNAMIC_LOAD_FLASHATTN_V1_WRAP
} // namespace dynload
} // namespace phi
......@@ -219,6 +219,23 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
}
}
void FlashAttnV1GradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
if (dq) {
dq->share_meta(q);
}
if (dk && k) {
dk->share_meta(k);
}
if (dv && v) {
dv->share_meta(v);
}
}
void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
......@@ -1175,4 +1192,32 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv) {
auto input_dims = dout_q.dims();
PADDLE_ENFORCE_EQ(
input_dims.size(),
4,
phi::errors::InvalidArgument("Input should be a 4-D tensor of format "
"[batch_size, seq_len, num_heads, head_dim],"
"but got %u.",
input_dims.size()));
if (dout_q) {
dq->set_dims(dout_q.dims());
dq->set_dtype(dout_q.dtype());
}
if (dout_k) {
dk->set_dims(dout_k.dims());
dk->set_dtype(dout_k.dtype());
}
if (dout_v) {
dv->set_dims(dout_v.dims());
dv->set_dtype(dout_v.dtype());
}
}
} // namespace phi
......@@ -179,11 +179,25 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
MetaTensor* dk,
MetaTensor* dv);
void FlashAttnV1GradInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
MetaTensor* y_grad);
void FusedRopeGradInferMeta(const MetaTensor& dout_q,
const MetaTensor& dout_k,
const MetaTensor& dout_v,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -1259,6 +1259,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dweight,
const MetaTensor& dbias,
bool multi_precision,
bool has_bias,
MetaTensor* dweight_out,
MetaTensor* dbias_out) {
const auto dtype = dout.dtype();
......@@ -1302,7 +1303,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
? DataType::FLOAT32
: dtype;
if (dbias_out) {
if (has_bias && dbias_out) {
dbias_out->set_dims({weight_dims[1]});
dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
}
......@@ -3227,6 +3228,33 @@ void FusedAdamInferMeta(
}
}
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v) {
auto input_dims = q.dims();
PADDLE_ENFORCE_EQ(input_dims.size(),
4,
phi::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.",
input_dims.size()));
if (q) {
out_q->set_dims(q.dims());
out_q->set_dtype(q.dtype());
}
if (k) {
out_k->set_dims(k.dims());
out_k->set_dtype(k.dtype());
}
if (v) {
out_v->set_dims(v.dims());
out_v->set_dtype(v.dtype());
}
}
void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
......
......@@ -265,6 +265,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dweight,
const MetaTensor& dbias,
bool multi_precision,
bool has_bias,
MetaTensor* dweight_out,
MetaTensor* dbias_out);
......@@ -605,4 +606,11 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out_q,
MetaTensor* out_k,
MetaTensor* out_v);
} // namespace phi
......@@ -269,6 +269,20 @@ void FlashAttnInferMeta(const MetaTensor& q,
out->set_layout(q.layout());
}
void FlashAttnV1InferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
auto out_dims = q.dims();
out_dims[3] = v.dims()[3];
out->set_dims(out_dims);
out->set_dtype(q.dtype());
out->set_layout(q.layout());
}
void ArangeInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
......
......@@ -71,6 +71,14 @@ void FlashAttnInferMeta(const MetaTensor& q,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);
void FlashAttnV1InferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);
void InstanceNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
......
......@@ -20,33 +20,38 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
template <typename T, typename Context>
void FlashAttnV1GradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashV1AttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
} // namespace phi
......@@ -40,6 +40,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
int64_t K,
int64_t N,
bool use_addto,
bool has_bias,
DenseTensor *dweight_out,
DenseTensor *dbias_out) {
constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value;
......@@ -65,7 +66,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
use_addto);
}
if (dbias_out == nullptr) return;
if (!has_bias) return;
if (!fuse_bias_grad) {
auto dout_copy = dout;
......@@ -126,6 +127,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
bool has_bias,
DenseTensor *dweight_out,
DenseTensor *dbias_out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -159,7 +161,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
multi_precision = false;
}
if (dbias_out) {
if (has_bias && dbias_out) {
ctx.template Alloc<T>(dbias_out);
}
......@@ -176,6 +178,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
PrintMeta<kLogLevel>(dweight_out, "dweight_out");
PrintMeta<kLogLevel>(dbias_out, "dbias_out");
VLOG(kLogLevel) << "multi_precision = " << multi_precision;
VLOG(kLogLevel) << "has_bias = " << has_bias;
VLOG(kLogLevel) << "use_addto = " << use_addto;
VLOG(kLogLevel) << "M = " << M;
VLOG(kLogLevel) << "N = " << N;
......@@ -183,11 +186,29 @@ void FusedLinearParamGradAdd(const Context &ctx,
}
if (multi_precision) {
FusedLinearParamGradAddImpl<T, MT, Context>(
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
FusedLinearParamGradAddImpl<T, MT, Context>(ctx,
x,
dout,
dbias,
M,
K,
N,
use_addto,
has_bias,
dweight_out,
dbias_out);
} else {
FusedLinearParamGradAddImpl<T, T, Context>(
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
FusedLinearParamGradAddImpl<T, T, Context>(ctx,
x,
dout,
dbias,
M,
K,
N,
use_addto,
has_bias,
dweight_out,
dbias_out);
}
}
......@@ -199,6 +220,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
bool has_bias,
DenseTensor *dweight_out,
DenseTensor *dbias_out) {
PADDLE_THROW(phi::errors::Unimplemented(
......
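For reference, the math this fused kernel computes (a single GEMM for dweight plus an optional bias-row reduction, with optional fp32 master-gradient accumulation) can be written in plain numpy. The function below is an illustrative reference, not Paddle's API; it shows what the new has_bias attribute switches off:

```python
import numpy as np

def fused_linear_param_grad_add_ref(x, dout, dweight=None, dbias=None,
                                    multi_precision=True, has_bias=True):
    """Reference semantics: accumulate dL/dW (and optionally dL/db) into
    running buffers, using the flattened [M, K] x [M, N] layout of the kernel."""
    acc_dtype = np.float32 if multi_precision else dout.dtype
    x2d = x.reshape(-1, x.shape[-1]).astype(acc_dtype)           # [M, K]
    dout2d = dout.reshape(-1, dout.shape[-1]).astype(acc_dtype)  # [M, N]

    dweight_out = x2d.T @ dout2d                                 # [K, N]
    if dweight is not None:                                      # addto accumulation
        dweight_out = dweight_out + dweight.astype(acc_dtype)

    dbias_out = None
    if has_bias:                                                 # has_bias == False skips
        dbias_out = dout2d.sum(axis=0)                           # the bias reduction entirely
        if dbias is not None:
            dbias_out = dbias_out + dbias.astype(acc_dtype)
    return dweight_out, dbias_out
```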
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeGradKernel(phi::Array<const T*, 3> ins_data,
int64_t batch_size,
int64_t seq_len,
int64_t num_heads,
int64_t head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int64_t index =
(static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
threadIdx.x) *
VecSize;
int64_t stride = static_cast<int64_t>(gridDim.x) *
static_cast<int64_t>(blockDim.x) * VecSize;
int64_t size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
#pragma unroll
for (int64_t nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int64_t pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0 + sin_value[ls_index] * p1;
result[ls_index] = cos_value[ls_index] * p1 - sin_value[pr_index] * p0;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
template <typename T, typename Context>
void FusedRopeGradKernel(const Context& dev_ctx,
const DenseTensor& dout_q,
const paddle::optional<DenseTensor>& dout_k,
const paddle::optional<DenseTensor>& dout_v,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
int64_t numel = dout_q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(dq);
// small size for broadcast
auto batch_size = dout_q.dims()[0];
auto num_heads = dout_q.dims()[2];
auto head_dim = dout_q.dims()[3];
auto seq_len = dout_q.dims()[1];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));
constexpr const int vec_size = 2;
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int64_t grid = config.block_per_grid.x;
int64_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
ins_data[0] = dout_q.data<T>();
outs_data[0] = dq->data<T>();
int num_inputs = 0;
if (dout_k.get_ptr()) {
dev_ctx.template Alloc<T>(dk);
outs_data[1] = dk->data<T>();
ins_data[1] = dout_k->data<T>();
num_inputs++;
}
if (dout_v.get_ptr()) {
dev_ctx.template Alloc<T>(dv);
outs_data[2] = dv->data<T>();
ins_data[2] = dout_v->data<T>();
num_inputs++;
}
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);
VectorizedFusedRopeGradKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_rotary_position_embedding_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedRopeGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16){};
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace phi {
namespace fusion {
template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
int64_t batch_size,
int64_t seq_len,
int64_t num_heads,
int64_t head_dim,
phi::Array<T*, 3> outs_data,
int num_inputs,
MPType div_c) {
int64_t index =
(static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x) +
threadIdx.x) *
VecSize;
int64_t stride = static_cast<int64_t>(gridDim.x) *
static_cast<int64_t>(blockDim.x) * VecSize;
int64_t size = batch_size * seq_len * num_heads * head_dim;
MPType sin_value[VecSize];
MPType cos_value[VecSize];
MPType result[VecSize];
T store[VecSize];
using VecType = phi::AlignedVector<T, VecSize>;
constexpr int kVectorsPerThread = VecSize / 2;
for (; index < size; index += stride) {
#pragma unroll
for (int64_t nx = 0; nx < VecSize; ++nx) {
// get sin_index and cos_index
int64_t index_wc = (index + nx) % (seq_len * num_heads * head_dim);
int64_t pos_seq = index_wc / (num_heads * head_dim);
MPType idx = static_cast<MPType>((index_wc % head_dim) / 2 * 2.0);
MPType indicses =
static_cast<MPType>(1) /
pow(static_cast<MPType>(10000), idx * static_cast<MPType>(div_c));
MPType value = pos_seq * indicses;
sin_value[nx] = sin(value);
cos_value[nx] = cos(value);
}
#pragma unroll
for (int iter = 0; iter < 3; iter++) {
if (iter > num_inputs) break;
const T* input = ins_data[iter] + index;
VecType* out = reinterpret_cast<VecType*>(outs_data[iter] + index);
#pragma unroll
for (int nx = 0; nx < kVectorsPerThread; ++nx) {
int pr_index = nx * 2;
int ls_index = pr_index + 1;
MPType p0 = static_cast<MPType>(input[pr_index]);
MPType p1 = static_cast<MPType>(input[ls_index]);
result[pr_index] = cos_value[pr_index] * p0;
result[pr_index] -= sin_value[pr_index] * p1;
result[ls_index] = sin_value[ls_index] * p0;
result[ls_index] += cos_value[ls_index] * p1;
store[pr_index] = static_cast<T>(result[pr_index]);
store[ls_index] = static_cast<T>(result[ls_index]);
}
out[0] = *(reinterpret_cast<VecType*>(store));
}
}
}
template <typename T, typename Context>
void FusedRopeKernel(const Context& dev_ctx,
const DenseTensor& q,
const paddle::optional<DenseTensor>& k,
const paddle::optional<DenseTensor>& v,
DenseTensor* out_q,
DenseTensor* out_k,
DenseTensor* out_v) {
int64_t numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
// small size for broadcast
auto batch_size = q.dims()[0];
auto num_heads = q.dims()[2];
auto head_dim = q.dims()[3];
auto seq_len = q.dims()[1];
PADDLE_ENFORCE_NE(head_dim % 2,
1,
phi::errors::InvalidArgument(
"The head_dim of input must be a multiple of 2."));
constexpr const int vec_size = 2;
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int64_t grid = config.block_per_grid.x;
int64_t block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
phi::Array<T*, 3> outs_data;
phi::Array<const T*, 3> ins_data;
ins_data[0] = q.data<T>();
outs_data[0] = out_q->data<T>();
int num_inputs = 0;
if (k.get_ptr()) {
dev_ctx.template Alloc<T>(out_k);
ins_data[1] = k->data<T>();
outs_data[1] = out_k->data<T>();
num_inputs++;
}
if (v.get_ptr()) {
dev_ctx.template Alloc<T>(out_v);
ins_data[2] = v->data<T>();
outs_data[2] = out_v->data<T>();
num_inputs++;
}
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);
VectorizedFusedRopeKernel<T, MPType, vec_size>
<<<grid, block, 0, stream>>>(ins_data,
batch_size,
seq_len,
num_heads,
head_dim,
outs_data,
num_inputs,
div_c);
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_rotary_position_embedding,
GPU,
ALL_LAYOUT,
phi::fusion::FusedRopeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16){};
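A plain-numpy restatement of what the forward kernel above computes for each (even, odd) channel pair, with theta = pos / 10000^(2i/head_dim); the grad kernel applies the inverse rotation with the same sin/cos values. This is an illustrative reference only:

```python
import numpy as np

def fused_rope_ref(x):
    """x: [batch_size, seq_len, num_heads, head_dim], head_dim even."""
    _b, seq_len, _h, head_dim = x.shape
    pos = np.arange(seq_len)[:, None]                    # pos_seq in the kernel
    idx = np.arange(0, head_dim, 2)                      # (index_wc % head_dim) / 2 * 2, i.e. 2*i
    inv_freq = 1.0 / np.power(10000.0, idx / head_dim)   # 10000^(-idx * div_c)
    theta = pos * inv_freq                               # [seq_len, head_dim // 2]
    sin = np.sin(theta)[None, :, None, :]
    cos = np.cos(theta)[None, :, None, :]

    p0, p1 = x[..., 0::2], x[..., 1::2]                  # (pr_index, ls_index) pairs
    out = np.empty_like(x)
    out[..., 0::2] = cos * p0 - sin * p1                 # result[pr_index]
    out[..., 1::2] = sin * p0 + cos * p1                 # result[ls_index]
    return out
```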
......@@ -25,6 +25,7 @@
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
DECLARE_bool(cudnn_deterministic);
......@@ -55,116 +56,89 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
const cudaStream_t stream = ctx.stream();
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
nullptr, // for calculation workspace size
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
const int64_t total_q = dims[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = dims[1];
const int head_size_og = dout.dims()[2];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
const bool zero_tensors = false;
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
nullptr,
&workspace_size,
stream,
seed,
offset);
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ =
phi::dynload::flash_attn_varlen_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
out.data(),
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -186,52 +160,86 @@ void FlashAttnGradKernel(const Context& ctx,
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size_og = dout.dims()[3];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
head_size_og,
head_size,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
const float scale = 1.0f / std::sqrt(head_size);
FlashAttnBwdParamsV2 params =
FlashAttnBwdParamsV2(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
q.dtype(),
seed_offset.data<int64_t>());
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
FlashAttnUnpaddedGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
dout,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
dq,
dk,
dv);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......
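Since the gradient kernel now routes through FlashAttention-2's flash_attn_bwd / flash_attn_varlen_bwd C API, a convenient sanity check is to compare forward outputs against a naive attention reference (illustrative numpy sketch, no dropout):

```python
import numpy as np

def naive_attention_ref(q, k, v, causal=False):
    """q, k, v: [batch_size, seq_len, num_heads, head_dim]; returns the output
    flash_attn computes without dropout, up to floating-point rounding."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale        # [b, h, sq, sk]
    if causal:
        sq, sk = scores.shape[-2:]
        mask = np.triu(np.ones((sq, sk), dtype=bool), k=1)     # mask future positions
        scores = np.where(mask, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)       # numerically stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", probs, v)              # back to [b, sq, h, d]
```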
......@@ -21,13 +21,13 @@
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
DECLARE_bool(cudnn_deterministic);
......@@ -35,30 +35,31 @@ DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
// q,k,v [total_*, num_heads, head_dim]
......@@ -69,126 +70,79 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seq_len_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seq_len_k = 128;
} else if (max_seqlen_k <= 256) {
seq_len_k = 256;
}
softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
ctx.template Alloc<T>(softmax);
}
uint64_t workspace_size;
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
bool succ =
phi::dynload::flash_attn_fwd(q.data(),
k.data(),
v.data(),
nullptr, // for calculation workspace size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd(
const int64_t total_q = dims[0];
const int num_heads = dims[1];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];
const int batch_size = cu_seqlens_q.numel() - 1;
// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
// TODO(umiswing): add shape check
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
const bool succ = phi::dynload::flash_attn_varlen_fwd(
q.data(),
k.data(),
v.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
params.return_softmax ? softmax->data() : nullptr,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
seed,
offset);
params.seed,
params.offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -197,10 +151,12 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......@@ -215,51 +171,81 @@ void FlashAttnKernel(const Context& ctx,
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
// TODO(umiswing): Add check shape
const float scale = 1.0f / std::sqrt(head_size);
FlashAttnFwdParamsV2<T> params =
FlashAttnFwdParamsV2<T>(ctx,
batch_size,
seqlen_q,
seqlen_k,
num_heads,
num_heads_k,
head_size,
dropout,
scale,
causal,
return_softmax,
q.dtype(),
is_test,
rng_name,
fixed_seed_offset.get_ptr(),
softmax,
softmax_lse,
seed_offset);
int64_t seq_len_k = k.dims()[1];
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
ctx.template Alloc<T>(out);
float scale = 1.0f / std::sqrt(head_size);
cudaStream_t stream = ctx.stream();
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
VLOG(4) << "FlashAttn fwd seed: " << params.seed
<< ", offset: " << params.offset;
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnUnpaddedKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
return_softmax,
is_test,
out,
softmax,
softmax_lse,
seed_offset);
bool succ = phi::dynload::flash_attn_fwd(
q.data(),
k.data(),
v.data(),
params.rng_state.data(),
out->data(),
params.return_softmax ? params.softmax->data() : nullptr,
params.softmax_lse->data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.return_softmax,
params.is_bf16,
stream,
params.seed,
params.offset);
PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}
......@@ -270,11 +256,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
ALL_LAYOUT,
phi::FlashAttnUnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
PD_REGISTER_KERNEL(flash_attn,
GPU,
ALL_LAYOUT,
phi::FlashAttnKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(3).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
template <typename T>
struct FlashAttnFwdParamsV2 {
int batch_size;
// For the padded kernel, max_seqlen_q and seqlen_q are the same.
int64_t max_seqlen_q;
// For the padded kernel, max_seqlen_k and seqlen_k are the same.
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool return_softmax;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor rng_state;
DenseTensor* softmax;
DenseTensor* softmax_lse;
DenseTensor* seed_offset;
FlashAttnFwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const bool _return_softmax,
const DataType q_dtype,
const bool is_test,
const std::string& rng_name,
const DenseTensor* const fixed_seed_offset_ptr,
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
scale(_scale),
dropout(_dropout),
causal(_causal),
return_softmax(_return_softmax),
softmax(_softmax),
softmax_lse(_softmax_lse),
seed_offset(_seed_offset) {
dropout = is_test ? 0.0f : _dropout;
is_bf16 = q_dtype == DataType::BFLOAT16;
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
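// (seed, offset) resolution order: an explicit fixed_seed_offset tensor wins;
// otherwise the generator named by rng_name is used; otherwise the device
// context's default generator increments its offset.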
if (fixed_seed_offset_ptr) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset_ptr->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
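// Round the head size and sequence lengths up to the multiples the FA2 C API
// expects (32 and 128 respectively). Illustration with assumed values:
// head_size = 96 -> 96, max_seqlen_q = 1000 -> 1024, max_seqlen_k = 300 -> 384.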
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_lse->Resize({batch_size, num_heads, max_seqlen_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
softmax->Resize(
{batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded});
ctx.template Alloc<T>(softmax);
}
}
};
struct FlashAttnBwdParamsV2 {
int batch_size;
int64_t max_seqlen_q;
int64_t max_seqlen_k;
int seqlen_q_rounded;
int seqlen_k_rounded;
int num_heads;
int num_heads_k;
int head_size;
int head_size_rounded;
float dropout;
float scale;
bool causal;
bool is_bf16;
uint64_t seed;
uint64_t offset;
DenseTensor softmax_d;
DenseTensor dq_accum;
DenseTensor rng_state;
FlashAttnBwdParamsV2(const GPUContext& ctx,
const int _batch_size,
const int64_t _max_seqlen_q,
const int64_t _max_seqlen_k,
const int _num_heads,
const int _num_heads_k,
const int _head_size,
const float _dropout,
const float _scale,
const bool _causal,
const DataType q_dtype,
const int64_t* seed_offset_data)
: batch_size(_batch_size),
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
dropout(_dropout),
scale(_scale),
causal(_causal) {
is_bf16 = q_dtype == DataType::BFLOAT16;
seed = static_cast<uint64_t>(seed_offset_data[0]);
offset = static_cast<uint64_t>(seed_offset_data[1]);
// (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t
// with the same size.
rng_state = Empty<int64_t>(ctx, {2});
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
head_size_rounded = round_multiple(head_size, 32);
seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
dq_accum = Empty<float>(
ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
}
};
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_v1_grad_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn_v1.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedGradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16;
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
bool zero_tensors = false;
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
uint64_t workspace_size;
// calculate workspace size before execution
bool succ = phi::dynload::flash_attn_bwd_v1(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
nullptr, // for calculation workspace size
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_bwd_v1(
q.data(),
k.data(),
v.data(),
dq->data(),
dk->data(),
dv->data(),
out.data(),
dout.data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
const_cast<float*>(softmax_lse.data<float>()),
dsoftmax.data(),
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
#endif
}
template <typename T, typename Context>
void FlashAttnV1GradKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const DenseTensor& dout,
float dropout,
bool causal,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
FlashAttnV1UnpaddedGradKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
out,
softmax_lse,
seed_offset,
dout,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
dq,
dk,
dv);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_v1_unpadded_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1UnpaddedGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
PD_REGISTER_KERNEL(flash_attn_v1_grad,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1GradKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn_v1.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnV1UnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
ctx.template Alloc<T>(out);
cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16;
// q,k,v [total_*, num_heads, head_dim]
auto dims = q.dims();
PADDLE_ENFORCE_EQ(
dims.size(),
3,
phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
"[total_seq_len, num_heads, head_dim]"));
int64_t total_q = dims[0];
int64_t num_heads = dims[1];
int64_t head_size = dims[2];
int64_t total_k = k.dims()[0];
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
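// num_splits = 1 forces a single-split reduction so results are bit-wise
// reproducible, trading some speed for determinism.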
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
softmax_lse->Resize({batch_size, num_heads, seq_len_q});
ctx.template Alloc<float>(softmax_lse);
if (return_softmax) {
// may allocate more space than *max_seqlen_k*
int64_t blocksize_c = head_size > 64 ? 128 : 256;
int64_t seq_len_k =
((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
if (max_seqlen_k <= 128) {
seq_len_k = 128;
} else if (max_seqlen_k <= 256) {
seq_len_k = 256;
}
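// e.g. head_size = 128 -> blocksize_c = 128 and max_seqlen_k = 300 -> seq_len_k = 384.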
softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
ctx.template Alloc<T>(softmax);
}
uint64_t workspace_size;
// TODO(kuizhiqing) pass allocation/empty func in capi to decouple
// calculate workspace size before execution
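// First call: the output and workspace pointers are null, so the C API only
// reports the required workspace_size; the second call below runs the kernel
// with the allocated workspace.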
bool succ = phi::dynload::flash_attn_fwd_v1(
q.data(),
k.data(),
v.data(),
nullptr, // for calculation workspace size
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
DenseTensor workspace;
if (workspace_size > 0) {
workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
}
succ = phi::dynload::flash_attn_fwd_v1(
q.data(),
k.data(),
v.data(),
out->data(),
cu_seqlens_q.data(),
cu_seqlens_k.data(),
total_q,
total_k,
batch_size,
num_heads,
head_size,
max_seqlen_q,
max_seqlen_k,
dropout,
scale,
zero_tensors,
causal,
is_bf16,
num_splits,
softmax_lse->data(),
return_softmax ? softmax->data() : nullptr,
workspace_size > 0 ? workspace.data() : nullptr,
&workspace_size,
stream,
seed,
offset);
if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error_v1()));
}
#endif
}
template <typename T, typename Context>
void FlashAttnV1Kernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]
auto dims = q.dims();
PADDLE_ENFORCE_EQ(dims.size(),
4,
phi::errors::InvalidArgument(
"flash_attn receive input with dim "
"[batch_size, seq_len, num_heads, head_dim]"));
int64_t batch_size = dims[0];
int64_t seq_len_q = dims[1];
int64_t num_heads = dims[2];
int64_t head_size = dims[3];
int64_t seq_len_k = k.dims()[1];
int64_t total_q = batch_size * seq_len_q;
int64_t total_k = batch_size * seq_len_k;
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});
DenseTensor cu_seqlens_q;
DenseTensor cu_seqlens_k;
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
ArangeNullaryKernel<int32_t, Context>(
ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);
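// cu_seqlens_q/k become [0, seq_len, 2 * seq_len, ..., batch_size * seq_len]:
// the cumulative-length format the unpadded (varlen) kernel expects for a
// batch of equal-length rows.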
FlashAttnV1UnpaddedKernel<T, Context>(ctx,
q_t_s,
k_t_s,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
scale,
dropout,
causal,
return_softmax,
is_test,
out,
softmax,
softmax_lse,
seed_offset);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(flash_attn_v1_unpadded,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1UnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(flash_attn_v1,
GPU,
ALL_LAYOUT,
phi::FlashAttnV1Kernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -13,6 +13,7 @@ env_dict={
'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@',
'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@',
'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@',
'FLASHATTN_V1_LIBRARIES': '@FLASHATTN_V1_LIBRARIES@',
'LAPACK_LIB':'@LAPACK_LIB@',
'GFORTRAN_LIB':'@GFORTRAN_LIB@',
'GNU_RT_LIB_1':'@GNU_RT_LIB_1@',
......
......@@ -37,3 +37,4 @@ from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
from . import dist_flash_attn
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedFlashAttn(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))
# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
return True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if (
is_enable_auto_rand_ctrl()
and not op_dist_attr.is_recompute
and rank_id in op_dist_attr.process_mesh.process_ids
):
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
if (
len(kwargs.get('fixed_seed_offset', [])) > 0
or len(src_op.input("fixed_seed_offset")) > 0
):
# TODO(kuizhiqing) recompute should go here
pass
else:
# determinate rng
q_var = main_block._var_recursive(kwargs['q'][0])
k_var = main_block._var_recursive(kwargs['k'][0])
q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
process_mesh = op_dist_attr.process_mesh
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
assert rng_name is not None and rng_name != ""
src_op._set_attr('rng_name', rng_name)
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is made deterministic by the saved mask, so no random state control is needed
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"flash_attn", DistributedFlashAttnImpl0("random_control")
)
......@@ -104,4 +104,5 @@ get_log_level_code = log_util.get_log_level_code
get_log_level_name = log_util.get_log_level_name
save_cache_table = fleet.save_cache_table
perf_test = fleet.perf_test
monitor_perf = fleet.monitor_perf
from .. import auto_parallel as auto
......@@ -382,80 +382,245 @@ class Fleet:
)
return self
def perf_test(self, round=50):
# test allreduce perf
def allreduce_test(iteration, x, group):
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
paddle.distributed.all_reduce(x, group=group)
paddle.device.cuda.synchronize()
end_t = time.time()
return (end_t - start_t) / iteration
# test reduce perf
def reduce_test(iteration, x, group):
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
# TODO: shuffle dst
paddle.distributed.reduce(x, dst=min(group.ranks), group=group)
paddle.device.cuda.synchronize()
end_t = time.time()
return (end_t - start_t) / iteration
# test broadcast perf
def broadcast_test(iteration, x, group):
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
# TODO: shuffle src
paddle.distributed.broadcast(
x, src=min(group.ranks), group=group
)
paddle.device.cuda.synchronize()
end_t = time.time()
return (end_t - start_t) / iteration
# test allreduce perf
def allreduce_test(
self,
iteration,
x,
group,
allreduce_size,
allreduce_thres_time,
warmup=False,
):
if group is None or group.nranks <= 1:
logger.warning("allreduce_test is invalid, group invalid!")
return
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
paddle.distributed.all_reduce(x, group=group)
paddle.device.cuda.synchronize()
end_t = time.time()
ret = (end_t - start_t) / iteration
if warmup:
return
logger.info(
f"[AllReduceTest] nbytes {allreduce_size}B test result: {ret} s/iter"
)
if allreduce_thres_time > -1 and ret > allreduce_thres_time:
logger.warning(
f"[Perf Warnning] AllReduce Test Timeout! {ret} > {allreduce_thres_time}"
)
# test reduce perf
def reduce_test(self, iteration, x, group, reduce_size, reduce_thres_time):
if group is None or group.nranks <= 1:
logger.warning("reduce_test is invalid, group invalid!")
return
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
paddle.distributed.reduce(x, dst=min(group.ranks), group=group)
paddle.device.cuda.synchronize()
end_t = time.time()
ret = (end_t - start_t) / iteration
logger.info(
f"[ReduceTest] nbytes {reduce_size}B test result: {ret} s/iter"
)
if reduce_thres_time > -1 and ret > reduce_thres_time:
logger.warning(
f"[Perf Warnning] Reduce Test Timeout! {ret} > {reduce_thres_time}"
)
# test broadcast perf
def broadcast_test(
self, iteration, x, group, broadcast_size, broadcast_thres_time
):
if group is None or group.nranks <= 1:
logger.warning("broadcast_test is invalid, group invalid!")
return
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
paddle.distributed.broadcast(x, src=min(group.ranks), group=group)
paddle.device.cuda.synchronize()
end_t = time.time()
ret = (end_t - start_t) / iteration
logger.info(
f"[BroadcastTest] nbytes {broadcast_size}B test result: {ret} s/iter"
)
if broadcast_thres_time > -1 and ret > broadcast_thres_time:
logger.warning(
f"[Perf Warnning] Broadcast Test Timeout! {ret} > {broadcast_thres_time}"
)
# test allgather perf
def allgather_test(
self, iteration, x, group, allgather_size, allgather_thres_time
):
if group is None or group.nranks <= 1:
logger.warning("allgather_test is invalid, group invalid!")
return
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
start_t = time.time()
for _ in range(iteration):
tmp = []
paddle.distributed.all_gather(tmp, x, group=group)
paddle.device.cuda.synchronize()
end_t = time.time()
ret = (end_t - start_t) / iteration
logger.info(
f"[AllgatherTest] nbytes {allgather_size}B test result: {ret} s/iter"
)
if allgather_thres_time > -1 and ret > allgather_thres_time:
logger.warning(
f"[Perf Warnning] Allgather Test Timeout! {ret} > {allgather_thres_time}"
)
# test reduce_scatter perf
def reduce_scatter_test(
self,
iteration,
x,
group,
reduce_scatter_size,
reduce_scatter_thres_time,
):
if group is None or group.nranks <= 1:
logger.warning("reduce_scatter_test is invalid, group invalid!")
return
paddle.distributed.barrier()
paddle.device.cuda.synchronize()
parallelism = group.nranks
output_shape = x.shape
if x.shape[0] % parallelism != 0:
logger.warning(
f"the shape of input[{x.shape[0]}] can't be divided exactly by reduce_scatter parallelism[{parallelism}], test stopped!"
)
return
output_shape[0] = output_shape[0] // parallelism
output = paddle.empty(shape=output_shape, dtype=x.dtype)
start_t = time.time()
for _ in range(iteration):
paddle.distributed.stream.reduce_scatter(
output,
x,
op=paddle.distributed.ReduceOp.SUM,
group=group,
sync_op=True,
)
paddle.device.cuda.synchronize()
end_t = time.time()
ret = (end_t - start_t) / iteration
logger.info(
f"[ReduceScatterTest] nbytes {reduce_scatter_size}B test result: {ret} s/iter"
)
if reduce_scatter_thres_time > -1 and ret > reduce_scatter_thres_time:
logger.warning(
f"[Perf Warnning] ReduceScatter Test Timeout! {ret} > {reduce_scatter_thres_time}"
)
def perf_test(self, round=50, test_comm=[], context={}, hcg=None):
if hcg is None:
hcg = self.get_hybrid_communicate_group()
hcg = self.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
sharding_group = hcg.get_sharding_parallel_group()
mp_group = hcg.get_model_parallel_group()
test_group = None
if dp_group.nranks > 1:
test_group = dp_group
elif sharding_group.nranks > 1:
test_group = sharding_group
else:
logger.warning(
f"hcg created with dp_degree: {dp_group.nranks} and sharding_degree: {sharding_group.nranks}, skipping perf test..."
)
return
# test 1M ~ 1G
nbytes = 1 << 20 # 1048576(1MB)
final_nbytes = 1 << 30 # 1073741824(1GB)
dtype = paddle.float32
# run only one iteration of the size loop when testing a specific package size.
test_specific_size = False
for k, st in context.items():
if st[0] > 0:
test_specific_size = True
break
if test_specific_size:
test_comm = list(context.keys())
if len(test_comm) == 0:
return
while nbytes <= final_nbytes:
x = paddle.zeros([nbytes // 4], dtype=dtype)
# warmup
allreduce_test(iteration=10, x=x, group=test_group)
# test-allreduce
ret = allreduce_test(iteration=round, x=x, group=test_group)
logger.info(
f"[AllReduceTest] nbytes {nbytes}B test result: {ret} s/iter"
self.allreduce_test(10, x, test_group, nbytes, -1, warmup=True)
allreduce_size, allreduce_thres_time = context.get(
"allreduce", [nbytes, -1]
)
reduce_size, reduce_thres_time = context.get("reduce", [nbytes, -1])
broadcast_size, broadcast_thres_time = context.get(
"broadcast", [nbytes, -1]
)
ret = reduce_test(iteration=round, x=x, group=test_group)
logger.info(
f"[ReduceTest] nbytes {nbytes}B test result: {ret} s/iter"
allgather_size, allgather_thres_time = context.get(
"allgather", [nbytes, -1]
)
ret = broadcast_test(iteration=round, x=x, group=test_group)
logger.info(
f"[BroadcastTest] nbytes {nbytes}B test result: {ret} s/iter"
reduce_scatter_size, reduce_scatter_thres_time = context.get(
"reduce_scatter", [nbytes, -1]
)
# inter machines
if "allreduce" in test_comm:
x = paddle.zeros([allreduce_size // 4], dtype=dtype)
self.allreduce_test(
round, x, test_group, allreduce_size, allreduce_thres_time
)
if "reduce" in test_comm:
x = paddle.zeros([reduce_size // 4], dtype=dtype)
self.reduce_test(
round, x, test_group, reduce_size, reduce_thres_time
)
if "broadcast" in test_comm:
x = paddle.zeros([broadcast_size // 4], dtype=dtype)
self.broadcast_test(
round, x, test_group, broadcast_size, broadcast_thres_time
)
# intra machines
if "allgather" in test_comm:
x = paddle.zeros([allgather_size // 4], dtype=dtype)
self.allgather_test(
round, x, mp_group, allgather_size, allgather_thres_time
)
if "reduce_scatter" in test_comm:
x = paddle.zeros([reduce_scatter_size // 4], dtype=dtype)
self.reduce_scatter_test(
round,
x,
mp_group,
reduce_scatter_size,
reduce_scatter_thres_time,
)
# run only one iteration of the size loop when testing a specific package size.
if test_specific_size:
break
nbytes = nbytes << 1
def monitor_perf(self, comm_type, round=50, size_and_time={}, hcg=None):
for size, time_thres in size_and_time.items():
context = {comm_type: [size, time_thres]}
self.perf_test(round=round, context=context, hcg=hcg)
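# Hedged usage sketch (size and threshold below are illustrative): warn when a
# 1 MiB allreduce averages above 5 ms over 30 rounds.
#   fleet.monitor_perf("allreduce", round=30, size_and_time={1 << 20: 0.005})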
def _init_hybrid_parallel_env(self):
"""initialize the hybrid environment"""
self.hybrid_configs = self._user_defined_strategy.hybrid_configs
......
......@@ -100,6 +100,15 @@ class DygraphShardingOptimizer:
elif not hasattr(p, "main_grad"):
p.clear_gradient(set_to_zero)
def filter_parameters(self, parameter_list, hcg):
sharding_parallel_rank = hcg.get_sharding_parallel_rank()
parameter_list = [
param
for param in parameter_list
if self._param2rank[param.name] == sharding_parallel_rank
]
return parameter_list
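# With dp and sharding overlapped, the later dp allreduce only needs the
# parameters owned by this sharding rank, so step()/minimize() call
# filter_parameters() before fused_allreduce_gradients.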
def _partition_parameters(self):
"""
Partitions parameters among sharding ranks.
......
......@@ -293,19 +293,17 @@ class HybridParallelClipGrad:
params_grads, global_norm_var_dist, global_norm_var_not_dist
)
def _comm_and_clip(
self, params_grads, global_norm_var_dist, global_norm_var_not_dist
):
# sharding first
sharding_flag = (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
)
def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1
dp_flag = self._hcg.get_data_parallel_world_size() > 1
mp_flag = self._hcg.get_model_parallel_world_size() > 1
# add all reduce to get global norm of distributed params_and_grads
pp_flag = self._hcg.get_pipe_parallel_world_size() > 1
# when g_shard_norm_align_dp is disabled, grads are sharded across the sharding group
if sharding_flag and not g_shard_norm_align_dp:
# norm of mp distributed variable
if mp_flag:
# the dist norm is reduced over the sharding group here, and over the mp/pp groups later
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_sharding_parallel_group(),
......@@ -315,21 +313,40 @@ class HybridParallelClipGrad:
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# norm of mp distributed variable
if mp_flag:
# the dist norm should be reduced over the sharding, mp and pp groups
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
# the else branch would suffice, but this branch remains here for numerical precision backward compatibility
if not (dp_flag and sharding_flag):
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
else:
# global_norm_var_dist should all reduce among model parallel group and pp group
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_model_parallel_group(),
)
if pp_flag:
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_pipe_parallel_group(),
)
# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
if pp_flag:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)
def _comm_and_clip(
self, params_grads, global_norm_var_dist, global_norm_var_not_dist
):
self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
global_norm_var_fp32 = paddle.sqrt(
global_norm_var_dist + global_norm_var_not_dist
)
......@@ -554,15 +571,21 @@ class HybridParallelOptimizer:
@no_grad()
@framework.dygraph_only
def step(self):
parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
dp_parameter_list = parameter_list
if self._sharding_enable:
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)
self._inner_opt.reduce_gradients(parameter_list, self._hcg)
# the later dp sync does not need the global parameter list
if not g_shard_norm_align_dp:
dp_parameter_list = self._inner_opt.filter_parameters(
parameter_list, self._hcg
)
if self._dp_enable:
fused_allreduce_gradients(list(parameters_list), self._hcg)
fused_allreduce_gradients(dp_parameter_list, self._hcg)
self._step(parameters_list)
self._step(parameter_list)
@no_grad()
def minimize(
......@@ -574,14 +597,20 @@ class HybridParallelOptimizer:
parameter_list = (
parameters if parameters else self._inner_opt._parameter_list
)
parameter_list = list(parameter_list)
dp_parameter_list = parameter_list
# Here sharding should use global parameter list
if self._sharding_enable:
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)
self._inner_opt.reduce_gradients(parameter_list, self._hcg)
# the later dp sync does not need the global parameter list
if not g_shard_norm_align_dp:
dp_parameter_list = self._inner_opt.filter_parameters(
parameter_list, self._hcg
)
if self._dp_enable:
fused_allreduce_gradients(list(parameter_list), self._hcg)
fused_allreduce_gradients(dp_parameter_list, self._hcg)
return self._inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set
......
......@@ -777,9 +777,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
if self._comm_overlap:
self._backward_step_count += 1
sync_step = self._backward_step_count - self.stage_id
if sync_step > 0 and sync_step % self.accumulate_steps == 0:
if sync_step > 0 and sync_step % self.num_stages == 0:
chunk_idx = self._virtual_pp_world_size - (
sync_step // self.accumulate_steps
sync_step // self.num_stages
)
for buffer in self._chunk_2_comm_buffers[chunk_idx]:
buffer.comm_grads()
......@@ -787,7 +787,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
if self.stage_id != 0:
if (
self._backward_step_count
== self.accumulate_steps * self._virtual_pp_world_size
== self.num_stages * self.num_model_chunks
):
for buffer in self._chunk_2_comm_buffers[0]:
buffer.comm_grads()
......@@ -796,11 +796,10 @@ class PipelineParallelWithInterleave(PipelineParallel):
if self._comm_overlap:
assert (
self._backward_step_count
== self.accumulate_steps * self._virtual_pp_world_size
), "backward step count should be equal to accumulate steps * "
"virtual pp world size, but get {}, excepted result is {}".format(
self._backward_step_count,
self.accumulate_steps * self._virtual_pp_world_size,
== self.num_stages * self.num_model_chunks
), (
"backward step count should be equal to accumulate steps * virtual pp world size,"
f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}"
)
for _, buffers in self._chunk_2_comm_buffers.items():
......@@ -863,7 +862,18 @@ class PipelineParallelWithInterleave(PipelineParallel):
self._forward_only = forward_only
# store the number of backward steps
self._backward_step_count = 0
assert (
self.accumulate_steps % self.num_stages == 0
), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
self.accumulate_steps, self.num_stages
)
per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
self._backward_step_count = (
-(per_stage_accumulate_steps - 1)
* self.num_stages
* self.num_model_chunks
)
# init some data buffers for interleave scheduler
self.input_tensors = [[] for _ in range(self.num_model_chunks)]
......
......@@ -55,16 +55,18 @@ class MixPrecisionLayer(nn.Layer):
), "In main_grad node, param.grad should be None, but find param[{}] has grad.".format(
param.name
)
if param.main_grad is None:
param.main_grad = core.eager.Tensor(
value=tmp_grad.cast(paddle.float32).value(),
place=tmp_grad.place,
name="main_grad@" + param.name,
)
else:
param.main_grad.add_(tmp_grad)
if tmp_grad._is_initialized():
# A previous PyLayer may return None, so check that the grad is valid before accumulating.
if param.main_grad is None:
param.main_grad = core.eager.Tensor(
value=tmp_grad.cast(paddle.float32).value(),
place=tmp_grad.place,
name="main_grad@" + param.name,
)
else:
param.main_grad.add_(tmp_grad)
tmp_grad._clear_data()
tmp_grad._clear_data()
return None
return param_hook
......
......@@ -60,6 +60,106 @@ class TestDistDPTraning(unittest.TestCase):
def test_communication_perf(self):
fleet.perf_test(round=1)
# test the comm types listed in test_comm, scanning package sizes from 1MB to 1GB
fleet.perf_test(
round=1,
test_comm=[
"allreduce",
"reduce",
"broadcast",
"allgather",
"reduce_scatter",
],
)
# context: {comm_type:[size, time]}
# only test allreduce for package(1024B) and time threshold(0.00000001s),
# and test allgather for package(8192B) and time threshold(2s),
fleet.perf_test(
round=30,
test_comm=[
"allreduce",
"reduce",
"broadcast",
"allgather",
"reduce_scatter",
],
context={
"allreduce": [1024, 0.00000001],
"reduce": [1024, 0.00000001],
"broadcast": [1024, 0.00000001],
"allgather": [8192, 2],
},
)
# test allreduce for specific size and time.
fleet.monitor_perf(
"allreduce",
round=50,
size_and_time={1024: 0.00000001, 4096: 0.01, 8192: 2},
)
class TestDistMPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
from paddle.distributed.fleet.base.topology import (
CommunicateTopology,
HybridCommunicateGroup,
)
topo = CommunicateTopology(
hybrid_group_names=["data", "pipe", "sharding", "model"],
dims=[1, 1, 1, 2],
)
self.hcg = HybridCommunicateGroup(topo)
def build_optimizer(self, model):
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True
)
optimizer = paddle.optimizer.SGD(
learning_rate=scheduler, parameters=model.parameters()
)
return scheduler, optimizer
def test_communication_perf(self):
# test the comm types listed in test_comm, scanning package sizes from 1MB to 1GB
fleet.perf_test(
round=1,
test_comm=["allreduce", "allgather", "reduce_scatter"],
hcg=self.hcg,
)
# context: {comm_type:[size, time]}
# only test reduce for package(1024B) and time threshold(1s),
# and test allgather for package(8192B) and time threshold(0.00000002s),
fleet.perf_test(
round=100000,
context={
"reduce": [1024, 1],
"allgather": [8192, 0.00000002],
"reduce_scatter": [8192, 0.00000002],
},
hcg=self.hcg,
)
# test allgather for specific size and time.
fleet.monitor_perf(
"allgather",
round=50,
size_and_time={1024: 1, 4096: 0.01, 8192: 0.00000002},
hcg=self.hcg,
)
if __name__ == "__main__":
......
......@@ -56,9 +56,27 @@ def attention_naive(q, k, v, causal=False):
return paddle.transpose(o, [0, 2, 1, 3])
is_sm8x = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 8
and paddle.device.cuda.get_device_capability()[1] >= 0
)
is_sm90 = (
core.is_compiled_with_cuda()
and paddle.device.cuda.get_device_capability()[0] == 9
and paddle.device.cuda.get_device_capability()[1] == 0
)
is_sm_supported = is_sm8x or is_sm90
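# The FlashAttention-2 build used here only targets compute capability 8.x
# (Ampere/Ada) and 9.0 (Hopper), hence the extra skip conditions below.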
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
not core.is_compiled_with_cuda()
or get_cuda_version() < 11040
or not is_sm_supported,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
"and device's compute capability must be 8.x or 90",
)
class TestFlashAttentionAPI(unittest.TestCase):
def setUp(self):
......
......@@ -19,6 +19,7 @@ from .fused_matmul_bias import fused_matmul_bias, fused_linear
from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
from .fused_rotary_position_embedding import fused_rotary_position_embedding
__all__ = [
......@@ -30,4 +31,5 @@ __all__ = [
'fused_bias_dropout_residual_layer_norm',
'fused_ec_moe',
'fused_dropout_add',
'fused_rotary_position_embedding',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.framework import in_dynamic_mode
def fused_rotary_position_embedding(q, k=None, v=None):
r"""
Fused rotary position embedding.
Args:
q (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of q must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
k (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of k must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
v (Tensor, optional): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape of v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
Returns:
out_q, out_k, out_v: Tensors after applying the fused rotary position embedding, each with the same shape and data type as its input.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding
q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
"""
if in_dynamic_mode():
return _C_ops.fused_rotary_position_embedding(q, k, v)
raise RuntimeError(
"This feature is currently supported only in dynamic mode and with CUDAPlace."
)
......@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
g_use_flash_attn_v1 = (
os.getenv('FLAGS_flash_attn_version', 'v2').strip().lower() == 'v1'
)
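# Hedged usage note: set the flag before launching the process to fall back to
# the legacy v1 kernels, e.g. `FLAGS_flash_attn_version=v1 python train.py`
# (train.py is a placeholder); any other value, or leaving it unset, keeps v2.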
def flash_attention(
query,
......@@ -24,6 +30,9 @@ def flash_attention(
dropout=0.0,
causal=False,
return_softmax=False,
*,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -57,7 +66,9 @@ def flash_attention(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): A fixed (seed, offset) pair used to generate the dropout mask.
training(bool): Whether it is in the training phase.
rng_name(str): The name of the Generator used to produce the dropout seed.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
......@@ -80,15 +91,29 @@ def flash_attention(
print(output)
"""
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn(
query,
key,
value,
dropout,
causal,
return_softmax,
not training,
)
if g_use_flash_attn_v1:
(result_attention, result_softmax, _, _) = _C_ops.flash_attn_v1(
query,
key,
value,
dropout,
causal,
return_softmax,
not training,
)
else:
(result_attention, result_softmax, _, _) = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
helper = LayerHelper('flash_attn', **locals())
......@@ -101,6 +126,7 @@ def flash_attention(
'q': query,
'k': key,
'v': value,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -117,6 +143,7 @@ def flash_attention(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
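# A minimal sketch of the new arguments (tensor shapes are illustrative
# assumptions, not requirements of this API):
#   q = paddle.randn([2, 128, 8, 64], dtype='float16')
#   fixed = paddle.to_tensor([2023, 0], dtype='int64')  # [seed, offset]
#   out, _ = flash_attention(q, q, q, dropout=0.1,
#                            fixed_seed_offset=fixed, rng_name="")
# A fixed seed/offset pins the dropout mask across runs; rng_name instead
# selects a named generator registered for auto-parallel random control.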
......@@ -134,6 +161,8 @@ def flash_attn_unpadded(
dropout=0.0,
causal=False,
return_softmax=False,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -174,6 +203,8 @@ def flash_attn_unpadded(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): A fixed (seed, offset) pair used to generate the dropout mask.
rng_name(str): The name of the Generator used to produce the dropout seed.
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
......@@ -197,20 +228,38 @@ def flash_attn_unpadded(
print(output)
"""
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
not training,
)
if g_use_flash_attn_v1:
(result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
not training,
)
else:
(result_attention, result_softmax,) = _C_ops.flash_attn_unpadded(
query,
key,
value,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
max_seqlen_q,
max_seqlen_k,
scale,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
helper = LayerHelper('flash_attn_unpadded', **locals())
......@@ -225,6 +274,7 @@ def flash_attn_unpadded(
'v': value,
'cu_seqlens_q': cu_seqlens_q,
'cu_seqlens_k': cu_seqlens_k,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -244,6 +294,7 @@ def flash_attn_unpadded(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
......@@ -599,6 +599,10 @@ if len('${FLASHATTN_LIBRARIES}') > 1:
package_data['paddle.libs']+=[os.path.basename('${FLASHATTN_LIBRARIES}')]
shutil.copy('${FLASHATTN_LIBRARIES}', libs_path)
if len('${FLASHATTN_V1_LIBRARIES}') > 1:
package_data['paddle.libs']+=[os.path.basename('${FLASHATTN_V1_LIBRARIES}')]
shutil.copy('${FLASHATTN_V1_LIBRARIES}', libs_path)
if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_SHARED_LIB}', libs_path)
shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path)
......
......@@ -1036,6 +1036,13 @@ def get_package_data_and_package_dir():
os.path.basename(env_dict.get("FLASHATTN_LIBRARIES"))
]
shutil.copy(env_dict.get("FLASHATTN_LIBRARIES"), libs_path)
if len(env_dict.get("FLASHATTN_V1_LIBRARIES", "")) > 1:
package_data['paddle.libs'] += [
os.path.basename(env_dict.get("FLASHATTN_V1_LIBRARIES"))
]
shutil.copy(env_dict.get("FLASHATTN_V1_LIBRARIES"), libs_path)
if env_dict.get("WITH_LITE") == 'ON':
shutil.copy(env_dict.get("LITE_SHARED_LIB"), libs_path)
package_data['paddle.libs'] += [
......
......@@ -75,9 +75,9 @@ set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
if(NOT WIN32)
set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 180)
set_tests_properties(test_tsm PROPERTIES TIMEOUT 900)
#set_tests_properties(test_resnet PROPERTIES TIMEOUT 120)
set_tests_properties(test_resnet PROPERTIES TIMEOUT 240)
endif()
if(APPLE)
......
......@@ -426,20 +426,6 @@ class TestResnet(unittest.TestCase):
)
self.verify_predict()
def test_resnet_composite_backward(self):
core._set_prim_backward_enabled(True)
static_loss = self.train(to_static=True)
core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_resnet_composite_forward_backward(self):
core._set_prim_all_enabled(True)
static_loss = self.train(to_static=True)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid import core
from paddle.incubate.nn.functional import fused_rotary_position_embedding
def deal_qkv(init_q, init_k, init_v):
perm = [0, 2, 1, 3]
q = paddle.transpose(x=init_q, perm=perm)
k = paddle.transpose(x=init_k, perm=perm)
v = paddle.transpose(x=init_v, perm=perm)
return q, k, v
def mult_qkv(value, cos_tensor, sin_tensor):
rotate_half_q = paddle.reshape(
paddle.stack([value[:, :, :, 1::2], value[:, :, :, 0::2]], axis=-1),
paddle.shape(value),
)
query = paddle.add(
paddle.multiply(value, cos_tensor),
paddle.multiply(rotate_half_q, sin_tensor),
)
return query
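# mult_qkv applies the pairwise rotation x * cos + rotate_half(x) * sin; the
# sign flip for the first element of each pair is baked into sin_tensor below
# (sin_sin stores [-sin, sin, -sin, sin, ...]).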
def paddle_fused_rotary_position_embedding(init_q, init_k, init_v):
q, k, v = deal_qkv(init_q, init_k, init_v)
pos_seq = paddle.arange(0, q.shape[2], 1, dtype="float32")
indices = paddle.arange(0, q.shape[3], 2, dtype="float32")
indices = 1 / 10000 ** (indices / q.shape[3])
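# Standard RoPE inverse frequencies: theta_i = 1 / 10000^(2i / head_dim).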
sinusoid_inp = pos_seq.unsqueeze(1) * indices.unsqueeze(0)
sin_sin = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
cos_cos = np.empty((q.shape[2] * q.shape[3]), dtype=np.float32)
numpy_array = sinusoid_inp.numpy()
iter_array = np.nditer(numpy_array)
i = 0
for value in iter_array:
sin_sin[i * 2] = -1 * np.sin(value)
cos_cos[i * 2 + 0] = np.cos(value)
sin_sin[i * 2 + 1] = np.sin(value)
cos_cos[i * 2 + 1] = np.cos(value)
i += 1
sin_tensor = paddle.reshape(
paddle.to_tensor(sin_sin, place=paddle.CPUPlace()),
[1, 1, q.shape[2], q.shape[3]],
)
cos_tensor = paddle.reshape(
paddle.to_tensor(cos_cos, place=paddle.CPUPlace()),
[1, 1, q.shape[2], q.shape[3]],
)
query = mult_qkv(q, cos_tensor, sin_tensor)
value = mult_qkv(v, cos_tensor, sin_tensor)
key = mult_qkv(k, cos_tensor, sin_tensor)
r_query, r_key, r_value = deal_qkv(query, key, value)
return r_query, r_key, r_value
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not compiled with CUDA ",
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
def setUp(self):
self.shape = [1, 16, 1, 16]
self.dtype = 'float32'
self.training = True
self.seed = 1203
def get_paddle_tensor(self):
tmp = paddle.randn(self.shape, self.dtype)
tmp.stop_gradient = False
return tmp
def get_forward_backward(self, rope_function, seed):
paddle.disable_static()
paddle.seed(seed)
fw = []
bw = []
tensor_q = self.get_paddle_tensor()
tensor_k = self.get_paddle_tensor()
tensor_v = self.get_paddle_tensor()
out_q, out_k, out_v = rope_function(tensor_q, tensor_k, tensor_v)
fw.append(out_q)
fw.append(out_k)
fw.append(out_v)
out_gq = paddle.randn(out_q.shape, self.dtype)
out_gk = paddle.randn(out_q.shape, self.dtype)
out_gv = paddle.randn(out_q.shape, self.dtype)
paddle.autograd.backward(
[out_q, out_k, out_v], [out_gq, out_gk, out_gv], True
)
bw.append(tensor_q)
bw.append(tensor_k)
bw.append(tensor_v)
return fw, bw
def test_fused_rotary_position_embedding(self):
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding, seed=self.seed
)
f_fw, f_bw = self.get_forward_backward(
fused_rotary_position_embedding, seed=self.seed
)
for i in range(len(p_fw)):
np.testing.assert_allclose(
p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
)
np.testing.assert_allclose(
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
def test_error(self):
paddle.enable_static()
with self.assertRaises(RuntimeError):
static_q = paddle.static.data(
name="q", shape=self.shape, dtype=self.dtype
)
fused_rotary_position_embedding(static_q, static_q, static_q)
paddle.disable_static()
if __name__ == '__main__':
unittest.main()