diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake
index eae35d90f50f097f9f63075bc919d11cb093c861..95893ad27a6a1b30c465ec1676ce23c205c461ec 100644
--- a/cmake/external/flashattn.cmake
+++ b/cmake/external/flashattn.cmake
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
+set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
 
 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 0a3384d13fc684030c8c1862cab6aaa248c576cb..2394182ee4bd1ccaf1be30af5d40d0066132f6de 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -617,7 +617,7 @@
   inplace : (out_grad -> x_grad)
 
 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -628,7 +628,7 @@
     data_type: q
 
 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 3c344060b9bcbf879c2d84f52c583e38c04128fb..fe743f9609136018ae194290881b5b2d2b6a7399 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -678,8 +678,9 @@
   backward : fill_diagonal_tensor_grad
 
 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
@@ -690,8 +691,9 @@
   backward : flash_attn_grad
 
 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h
index 54c77a9dea7397d751e2e8be0018e96fcd9132b9..296e24202608753c0eedd7ab6715922e74a974d7 100644
--- a/paddle/phi/kernels/flash_attn_kernel.h
+++ b/paddle/phi/kernels/flash_attn_kernel.h
@@ -20,33 +20,38 @@
 namespace phi {
 
 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const DenseTensor& cu_seqlens_q,
+    const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
+    float dropout,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
+    const std::string& rng_name,
+    DenseTensor* out,
+    DenseTensor* softmax,
+    DenseTensor* softmax_lse,
+    DenseTensor* seed_offset);
 
 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
#include "paddle/phi/kernels/flash_attn_grad_kernel.h" +#include "glog/logging.h" // For VLOG() #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" @@ -25,6 +27,8 @@ #include "paddle/phi/backends/dynload/flashattn.h" #endif +DECLARE_bool(cudnn_deterministic); + namespace phi { template @@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, int64_t batch_size = cu_seqlens_q.numel() - 1; int num_splits = 0; // 0 for an internal heuristic, which is optimal + if (FLAGS_cudnn_deterministic) { + num_splits = 1; + } bool zero_tensors = false; const int64_t* seed_offset_data = seed_offset.data(); uint64_t seed = static_cast(seed_offset_data[0]); uint64_t offset = static_cast(seed_offset_data[1]); + VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset + << ", num_splits:" << num_splits; + int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16; DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seq_len_q}); @@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx, float scale = 1.0f / std::sqrt(head_size); + VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; + DenseTensor q_t_s, k_t_s, v_t_s; q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size}); k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size}); diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 7c2cd423dd03283b3431565151aa4ea331941f5f..714edf4be6f3c5313c81ee0581573326de6be6c7 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -14,12 +14,13 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" +#include "glog/logging.h" // For VLOG() #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" - #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h" @@ -28,26 +29,31 @@ #include "paddle/phi/backends/dynload/flashattn.h" #endif +DECLARE_bool(cudnn_deterministic); + namespace phi { template -void FlashAttnUnpaddedKernel(const Context& ctx, - const DenseTensor& q, - const DenseTensor& k, - const DenseTensor& v, - const DenseTensor& cu_seqlens_q, - const DenseTensor& cu_seqlens_k, - int64_t max_seqlen_q, - int64_t max_seqlen_k, - float scale, - float dropout, - bool causal, - bool return_softmax, - bool is_test, - DenseTensor* out, - DenseTensor* softmax, - DenseTensor* softmax_lse, - DenseTensor* seed_offset) { +void FlashAttnUnpaddedKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN if (is_test) dropout = 0.0f; @@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx, int64_t 
batch_size = cu_seqlens_q.numel() - 1; int num_splits = 0; // 0 for an internal heuristic, which is optimal + if (FLAGS_cudnn_deterministic) { + num_splits = 1; + } bool zero_tensors = false; - auto gen = ctx.GetGenerator(); - uint64_t inc = batch_size * num_heads * 32; - auto seed_offset_pair = gen->IncrementOffset(inc); + uint64_t seed; + uint64_t offset; + + if (fixed_seed_offset.get_ptr()) { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset.get_ptr()->data(); + seed = static_cast(fixed_seed_offset_data[0]); + offset = static_cast(fixed_seed_offset_data[1]); + } else { + uint64_t inc = batch_size * num_heads * 32; + std::pair seed_offset_pair; + if (rng_name != "") { + auto gen = phi::GetRandomSeedGenerator(rng_name); + seed_offset_pair = gen->IncrementOffset(inc); + } else { + auto* gen = ctx.GetGenerator(); + seed_offset_pair = gen->IncrementOffset(inc); + } + seed = seed_offset_pair.first; + offset = seed_offset_pair.second; + } - uint64_t seed = seed_offset_pair.first; - uint64_t offset = seed_offset_pair.second; + VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset + << ", num_splits:" << num_splits; seed_offset->Resize({2}); - auto* seed_offset_data = ctx.template HostAlloc(seed_offset); + int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); seed_offset_data[0] = static_cast(seed); seed_offset_data[1] = static_cast(offset); @@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx, const DenseTensor& q, const DenseTensor& k, const DenseTensor& v, + const paddle::optional& fixed_seed_offset, float dropout, bool causal, bool return_softmax, bool is_test, + const std::string& rng_name, DenseTensor* out, DenseTensor* softmax, DenseTensor* softmax_lse, @@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx, float scale = 1.0f / std::sqrt(head_size); + VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; + DenseTensor q_t_s, k_t_s, v_t_s; q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size}); k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size}); @@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx, v_t_s, cu_seqlens_q, cu_seqlens_k, + fixed_seed_offset, seq_len_q, seq_len_k, scale, @@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx, causal, return_softmax, is_test, + rng_name, out, softmax, softmax_lse, @@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded, ALL_LAYOUT, phi::FlashAttnUnpaddedKernel, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16) { + kernel->InputAt(5).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} PD_REGISTER_KERNEL(flash_attn, GPU, ALL_LAYOUT, phi::FlashAttnKernel, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 862a51fc41a9e065924ceb5eac88fc3e09c89116..8efb6cf068569cb6c9f32742b56fc403209ba4fd 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -38,3 +38,4 @@ from . import dist_shape from . import dist_assign from . import dist_scale from . import dist_dropout +from . 
import dist_flash_attn diff --git a/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py b/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..331bdfd25ae0aba86e6ffbca7bde4683588ee75c --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_flash_attn.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import logging + +from ...utils.log_utils import get_logger + +_logger = get_logger(logging.INFO) +from ..random import determinate_rng, is_enable_auto_rand_ctrl +from .common import ( + DistributedOperatorImplContainer, + register_distributed_operator_impl, + register_distributed_operator_impl_container, +) +from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0 + + +class DistributedFlashAttn(DistributedOperatorImplContainer): + def __init__(self, op_type): + super().__init__(op_type) + + +register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn")) + + +# Dist FlashAttn with Random Control +class DistributedFlashAttnImpl0(DistributedElementwiseImpl0): + def __init__(self, name): + super().__init__(name) + self._forward_implemented = True + self._backward_implemented = True + + def is_input_compatible(self, dist_op): + return True + + def is_output_compatible(self, dist_op): + return True + + def is_auto_compatible(self, dist_op): + return True + + @staticmethod + def forward(ctx, *args, **kwargs): + + dist_op_context = ctx.dist_op_context + main_block = dist_op_context.work_block + startup_block = dist_op_context.startup_block + src_op = dist_op_context.cur_src_op + rank_id = dist_op_context.rank_id + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) + + if ( + is_enable_auto_rand_ctrl() + and not op_dist_attr.is_recompute + and rank_id in op_dist_attr.process_mesh.process_ids + ): + + assert ( + op_dist_attr is not None + ), f"forward op [{str(src_op)}] don't have dist attribute !" 
+
+            if (
+                len(kwargs.get('fixed_seed_offset', [])) > 0
+                or len(src_op.input("fixed_seed_offset")) > 0
+            ):
+                # TODO(kuizhiqing) recompute should go here
+                pass
+            else:
+                # determinate rng
+                q_var = main_block._var_recursive(kwargs['q'][0])
+                k_var = main_block._var_recursive(kwargs['k'][0])
+                q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
+                k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
+                process_mesh = op_dist_attr.process_mesh
+                dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
+
+                rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
+                assert rng_name is not None and rng_name != ""
+
+                src_op._set_attr('rng_name', rng_name)
+
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        # dropout backward is deterministic by mask, and does not need random state control
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
+
+
+register_distributed_operator_impl(
+    "flash_attn", DistributedFlashAttnImpl0("random_control")
+)
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 0911888799fc456bfda88f0d8f868b9eb57a1ca1..78c2dd6e7618af285d5188dc1989f87644cdb160 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -24,6 +24,9 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    *,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -57,7 +60,9 @@
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): Fixed seed and offset for the dropout mask. Default is None.
         training(bool): Whether it is in the training phase.
+        rng_name(str): The name of the Generator used for dropout. Default is an empty string.
         name(str, optional): The default value is None. Normally there is no need for user
                         to set this property. For more information, please refer to
                         :ref:`api_guide_Name`.
@@ -84,10 +89,12 @@
             query,
             key,
             value,
+            fixed_seed_offset,
             dropout,
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
 
@@ -101,6 +108,7 @@
         'q': query,
         'k': key,
         'v': value,
+        'fixed_seed_offset': fixed_seed_offset,
     }
     outputs = {
         'out': out,
@@ -117,6 +125,7 @@
             'causal': causal,
             'return_softmax': return_softmax,
             'is_test': not training,
+            'rng_name': rng_name,
         },
     )
     return out, softmax if return_softmax else None
@@ -134,6 +143,8 @@
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -174,6 +185,8 @@
         dropout(float): The dropout ratio.
        causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): Fixed seed and offset for the dropout mask. Default is None.
+        rng_name(str): The name of the Generator used for dropout. Default is an empty string.
         training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
                         to set this property. For more information, please refer to
@@ -203,6 +216,7 @@
             value,
             cu_seqlens_q,
             cu_seqlens_k,
+            fixed_seed_offset,
             max_seqlen_q,
             max_seqlen_k,
             scale,
@@ -210,6 +224,7 @@
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
 
@@ -225,6 +240,7 @@
         'q': query,
         'k': key,
         'v': value,
         'cu_seqlens_q': cu_seqlens_q,
         'cu_seqlens_k': cu_seqlens_k,
+        'fixed_seed_offset': fixed_seed_offset,
     }
     outputs = {
         'out': out,
@@ -244,6 +260,7 @@
             'causal': causal,
             'return_softmax': return_softmax,
             'is_test': not training,
+            'rng_name': rng_name,
         },
     )
     return out, softmax if return_softmax else None
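Note (not part of the patch): a minimal usage sketch of the two new arguments added above, assuming a GPU build with FlashAttention enabled; the shapes, dtype and seed values below are illustrative only. fixed_seed_offset is an optional int64 tensor holding a [seed, offset] pair that pins the dropout mask, while rng_name selects a named Generator and is normally set by the auto-parallel pass in dist_flash_attn.py rather than by users.

    import paddle
    from paddle.nn.functional.flash_attention import flash_attention

    # [batch_size, seq_len, num_heads, head_dim]; FlashAttention expects fp16/bf16 inputs.
    q = paddle.randn([2, 128, 8, 64], dtype='float16')
    k = paddle.randn([2, 128, 8, 64], dtype='float16')
    v = paddle.randn([2, 128, 8, 64], dtype='float16')

    # Pin the dropout mask to a fixed (seed, offset) pair so repeated runs match.
    fixed = paddle.to_tensor([2023, 0], dtype='int64')
    out, _ = flash_attention(
        q, k, v, dropout=0.1, causal=True, fixed_seed_offset=fixed
    )

Passing the same fixed_seed_offset twice reproduces the same dropout pattern, which is what the auto-parallel random-control path relies on; leaving it as None falls back to the (possibly named) Generator, exactly as in the kernel changes above.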