Unverified commit 1d23e0bb authored by zhangkaihuo, committed by GitHub

[cherry-pick] add flash randomness control and add scaled_dot_product_attention (#53518)

att, cherry-pick: #52902 #53113
Parent commit: 39b704c1
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
+set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
......
@@ -617,7 +617,7 @@
   inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -628,7 +628,7 @@
     data_type: q

 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
......
@@ -678,8 +678,9 @@
   backward : fill_diagonal_tensor_grad

 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
@@ -690,8 +691,9 @@
   backward : flash_attn_grad

 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
......
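Because both ops mark `fixed_seed_offset` as `optional`, Python callers may simply pass `None`; to pin the dropout mask they pass a two-element int64 tensor holding a seed and an offset. A minimal sketch (the concrete values and the CPU placement are illustrative assumptions; the GPU kernels below read the pair on the host, and their registrations accept this input from any backend):

```python
import paddle

# Hypothetical (seed, offset) pair; any fixed pair makes the dropout mask reproducible.
fixed_seed_offset = paddle.to_tensor([2023, 0], dtype='int64', place=paddle.CPUPlace())
```

This is the tensor that the keyword-only `fixed_seed_offset` argument of `flash_attention`, added later in this diff, forwards to the kernel.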
@@ -20,12 +20,14 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
     const DenseTensor& q,
     const DenseTensor& k,
     const DenseTensor& v,
     const DenseTensor& cu_seqlens_q,
     const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
     int64_t max_seqlen_q,
     int64_t max_seqlen_k,
     float scale,
@@ -33,6 +35,7 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
     bool causal,
     bool return_softmax,
     bool is_test,
+    const std::string& rng_name,
     DenseTensor* out,
     DenseTensor* softmax,
     DenseTensor* softmax_lse,
@@ -43,10 +46,12 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
......
@@ -13,8 +13,10 @@
 // limitations under the License.

 #include "paddle/phi/kernels/flash_attn_grad_kernel.h"
+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
@@ -25,6 +27,8 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename T, typename Context>
@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;

   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;

   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
   uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
   uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);

+  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;
+
   int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;

   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);

+  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
......
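Both backward kernels now consult `FLAGS_cudnn_deterministic` and fall back to `num_splits = 1`, trading some speed for gradients that are reduced in a fixed order. A hedged sketch of how a user could turn that on from Python (assuming the usual `paddle.set_flags` mechanism for global flags; exporting `FLAGS_cudnn_deterministic=1` in the environment works as well):

```python
import paddle

# Assumed usage: make the FlashAttention backward kernels above deterministic
# by forcing num_splits = 1, as the new FLAGS_cudnn_deterministic branch does.
paddle.set_flags({'FLAGS_cudnn_deterministic': True})
```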
@@ -14,12 +14,13 @@

 #include "paddle/phi/kernels/flash_attn_kernel.h"
+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"
@@ -28,15 +29,19 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
     const DenseTensor& q,
     const DenseTensor& k,
     const DenseTensor& v,
     const DenseTensor& cu_seqlens_q,
     const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
     int64_t max_seqlen_q,
     int64_t max_seqlen_k,
     float scale,
@@ -44,6 +49,7 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
     bool causal,
     bool return_softmax,
     bool is_test,
+    const std::string& rng_name,
     DenseTensor* out,
     DenseTensor* softmax,
     DenseTensor* softmax_lse,
@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;

   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;

-  auto gen = ctx.GetGenerator();
-  uint64_t inc = batch_size * num_heads * 32;
-  auto seed_offset_pair = gen->IncrementOffset(inc);
-  uint64_t seed = seed_offset_pair.first;
-  uint64_t offset = seed_offset_pair.second;
+  uint64_t seed;
+  uint64_t offset;
+
+  if (fixed_seed_offset.get_ptr()) {
+    const int64_t* fixed_seed_offset_data =
+        fixed_seed_offset.get_ptr()->data<int64_t>();
+    seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
+    offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
+  } else {
+    uint64_t inc = batch_size * num_heads * 32;
+    std::pair<uint64_t, uint64_t> seed_offset_pair;
+    if (rng_name != "") {
+      auto gen = phi::GetRandomSeedGenerator(rng_name);
+      seed_offset_pair = gen->IncrementOffset(inc);
+    } else {
+      auto* gen = ctx.GetGenerator();
+      seed_offset_pair = gen->IncrementOffset(inc);
+    }
+    seed = seed_offset_pair.first;
+    offset = seed_offset_pair.second;
+  }
+
+  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;

   seed_offset->Resize({2});
-  auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
+  int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   seed_offset_data[0] = static_cast<int64_t>(seed);
   seed_offset_data[1] = static_cast<int64_t>(offset);
@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);

+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
@@ -235,6 +267,7 @@
       v_t_s,
       cu_seqlens_q,
       cu_seqlens_k,
+      fixed_seed_offset,
       seq_len_q,
       seq_len_k,
       scale,
@@ -242,6 +275,7 @@
       causal,
       return_softmax,
       is_test,
+      rng_name,
       out,
       softmax,
       softmax_lse,
@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
                    ALL_LAYOUT,
                    phi::FlashAttnUnpaddedKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(5).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}

 PD_REGISTER_KERNEL(flash_attn,
                    GPU,
                    ALL_LAYOUT,
                    phi::FlashAttnKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(3).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}
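The forward kernels now pick the dropout seed and offset with a three-way priority: an explicit `fixed_seed_offset` wins, otherwise a generator looked up by `rng_name`, otherwise the device's default generator. A pure-Python sketch of that policy (the helper and the generator objects here are illustrative, not Paddle APIs):

```python
def select_seed_offset(fixed_seed_offset, rng_name, batch_size, num_heads,
                       default_generator, named_generators):
    """Mirrors the priority order used by FlashAttnUnpaddedKernel above."""
    if fixed_seed_offset is not None:
        # Caller pinned the RNG state: dropout masks are fully reproducible.
        return int(fixed_seed_offset[0]), int(fixed_seed_offset[1])
    inc = batch_size * num_heads * 32  # same increment the kernel requests
    gen = named_generators[rng_name] if rng_name else default_generator
    return gen.increment_offset(inc)  # returns a (seed, offset) pair
```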
@@ -38,3 +38,4 @@ from . import dist_shape
 from . import dist_assign
 from . import dist_scale
 from . import dist_dropout
+from . import dist_flash_attn
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging

from ...utils.log_utils import get_logger

_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
    DistributedOperatorImplContainer,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0


class DistributedFlashAttn(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))


# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        return True

    @staticmethod
    def forward(ctx, *args, **kwargs):
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)

        if (
            is_enable_auto_rand_ctrl()
            and not op_dist_attr.is_recompute
            and rank_id in op_dist_attr.process_mesh.process_ids
        ):
            assert (
                op_dist_attr is not None
            ), f"forward op [{str(src_op)}] doesn't have a dist attribute!"

            if (
                len(kwargs.get('fixed_seed_offset', [])) > 0
                or len(src_op.input("fixed_seed_offset")) > 0
            ):
                # TODO(kuizhiqing) recompute should go here
                pass
            else:
                # determinate rng
                q_var = main_block._var_recursive(kwargs['q'][0])
                k_var = main_block._var_recursive(kwargs['k'][0])
                q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
                k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
                process_mesh = op_dist_attr.process_mesh
                dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]

                rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
                assert rng_name is not None and rng_name != ""

                src_op._set_attr('rng_name', rng_name)

        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        # dropout backward is deterministic by mask, so no random state control
        # is needed
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "flash_attn", DistributedFlashAttnImpl0("random_control")
)
@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False

     def test_unpadded(self):
         print(
@@ -189,6 +190,16 @@
             value, place=self.place, dtype=self.dtype, stop_gradient=False
         )

+        if self.use_sdp_kernel:
+            with paddle.nn.functional.sdp_kernel(
+                enable_math=self.enable_math,
+                enable_flash=self.enable_flash,
+                enable_mem_efficient=self.enable_mem_efficient,
+            ):
+                out, _ = flash_attention(
+                    q, k, v, self.dropout, self.causal, self.return_softmax
+                )
+        else:
             out, _ = flash_attention(
                 q, k, v, self.dropout, self.causal, self.return_softmax
             )
@@ -220,6 +231,21 @@
             name="v", shape=self.shape, dtype=self.dtype
         )

+        if self.use_sdp_kernel:
+            with paddle.nn.functional.sdp_kernel(
+                enable_math=self.enable_math,
+                enable_flash=self.enable_flash,
+                enable_mem_efficient=self.enable_mem_efficient,
+            ):
+                outs, softmax = flash_attention(
+                    qs,
+                    ks,
+                    vs,
+                    self.dropout,
+                    self.causal,
+                    self.return_softmax,
+                )
+        else:
             outs, softmax = flash_attention(
                 qs, ks, vs, self.dropout, self.causal, self.return_softmax
             )
@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False


 class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = True
+        self.use_sdp_kernel = False


 class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = True
         self.return_softmax = False
+        self.use_sdp_kernel = False


 class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
         self.dropout = 0.0
         self.causal = False
         self.return_softmax = False
+        self.use_sdp_kernel = False
+
+
+class TestMathAttentionAPITest(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = True
+        self.enable_math = True
+        self.enable_flash = False
+        self.enable_mem_efficient = False


 if __name__ == '__main__':
......
@@ -134,6 +134,8 @@ from .extension import gather_tree  # noqa: F401
 from .extension import temporal_shift  # noqa: F401

 from .sparse_attention import sparse_attention
+from .flash_attention import scaled_dot_product_attention
+from .flash_attention import sdp_kernel

 __all__ = [  # noqa
     'celu',
......
@@ -13,8 +13,113 @@
 # limitations under the License.

 import paddle
+import paddle.nn.functional as F
 from paddle import _C_ops, in_dynamic_mode
 from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
+
+g_enable_math = None
+g_enable_flash = None
+g_enable_mem_efficient = None
+
+
+@signature_safe_contextmanager
+def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
+    r"""
+    With the sdp_kernel context manager, different algorithm implementations can
+    be selected for scaled_dot_product_attention.
+    """
+    global g_enable_math, g_enable_flash, g_enable_mem_efficient
+    original_enable_math = g_enable_math
+    original_enable_flash = g_enable_flash
+    original_enable_mem_efficient = g_enable_mem_efficient
+
+    g_enable_math = enable_math
+    g_enable_flash = enable_flash
+    g_enable_mem_efficient = enable_mem_efficient
+    try:
+        yield
+    finally:
+        g_enable_math = original_enable_math
+        g_enable_flash = original_enable_flash
+        g_enable_mem_efficient = original_enable_mem_efficient
+
+
+def _math_attention(
+    query,
+    key,
+    value,
+    dropout_rate=0.0,
+    causal=False,
+    return_softmax=False,
+    training=True,
+):
+    r"""
+    This is a basic implementation of scaled dot product attention composed of
+    combinations of fundamental components.
+    """
+    head_dim = query.shape[-1]
+    query = paddle.transpose(query, [0, 2, 1, 3])
+    key = paddle.transpose(key, [0, 2, 1, 3])
+    value = paddle.transpose(value, [0, 2, 1, 3])
+    product = paddle.matmul(
+        x=query * (head_dim**-0.5), y=key, transpose_y=True
+    )
+    weights = (
+        paddle.incubate.softmax_mask_fuse_upper_triangle(product)
+        if causal
+        else F.softmax(product)
+    )
+    if dropout_rate > 0.0:
+        weights = F.dropout(
+            weights, dropout_rate, training=training, mode="upscale_in_train"
+        )
+
+    out = paddle.matmul(weights, value)
+    out = paddle.transpose(out, [0, 2, 1, 3])
+    return out, weights if return_softmax else None
+
+
+def _select_sdp_cuda(head_dim):
+    if head_dim < 128:
+        return "flash_attn"
+    else:
+        return "mem_efficient"
+
+
+def _select_sdp(head_dim):
+    r"""
+    There are currently three implementations of scaled dot product attention;
+    the one chosen depends on the sdp_kernel configuration, the device, and the
+    head dimension of the inputs.
+    """
+    place = paddle.get_device()
+    # sdp_kernel not used
+    if g_enable_flash is None:
+        if "gpu" not in place:
+            return "math"
+        else:
+            return _select_sdp_cuda(head_dim)
+
+    if (
+        g_enable_math is False
+        and g_enable_flash is False
+        and g_enable_mem_efficient is False
+    ):
+        raise AssertionError(
+            "No available backend for scaled_dot_product_attention was found."
+        )
+
+    if g_enable_math is True:
+        if g_enable_flash is False and g_enable_mem_efficient is False:
+            return "math"
+        if "gpu" not in place:
+            return "math"
+    if g_enable_flash is True and g_enable_mem_efficient is True:
+        return _select_sdp_cuda(head_dim)
+    if g_enable_flash is True:
+        return "flash_attn"
+    return "mem_efficient"

 def flash_attention(
@@ -24,6 +129,9 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    *,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -57,7 +165,9 @@ def flash_attention(
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): Fixed seed and offset used to generate the dropout mask.
         training(bool): Whether it is in the training phase.
+        rng_name(str): The name used to select the Generator.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -79,15 +189,21 @@ def flash_attention(
             output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
             print(output)
     """
+    head_dim = query.shape[3]
+    sdp_func_name = _select_sdp(head_dim)
+
+    if sdp_func_name == "flash_attn":
         if in_dynamic_mode():
             (result_attention, result_softmax,) = _C_ops.flash_attn(
                 query,
                 key,
                 value,
+                fixed_seed_offset,
                 dropout,
                 causal,
                 return_softmax,
                 not training,
+                rng_name,
             )
             return result_attention, result_softmax if return_softmax else None
@@ -101,6 +217,7 @@ def flash_attention(
             'q': query,
             'k': key,
             'v': value,
+            'fixed_seed_offset': fixed_seed_offset,
         }
         outputs = {
             'out': out,
@@ -117,9 +234,36 @@ def flash_attention(
                 'causal': causal,
                 'return_softmax': return_softmax,
                 'is_test': not training,
+                'rng_name': rng_name,
             },
         )
         return out, softmax if return_softmax else None
+    else:
+        if sdp_func_name == "mem_efficient":
+            from paddle.incubate.nn.memory_efficient_attention import (
+                memory_efficient_attention,
+            )
+
+            output = memory_efficient_attention(
+                query,
+                key,
+                value,
+                attn_bias=None,
+                p=dropout,
+                scale=None,
+                training=training,
+            )
+            return output, None
+        else:
+            return _math_attention(
+                query,
+                key,
+                value,
+                dropout_rate=dropout,
+                causal=causal,
+                return_softmax=return_softmax,
+                training=training,
+            )

 def flash_attn_unpadded(
@@ -134,6 +278,8 @@ def flash_attn_unpadded(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -174,6 +320,8 @@
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): Fixed seed and offset used to generate the dropout mask.
+        rng_name(str): The name used to select the Generator.
         training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
@@ -203,6 +351,7 @@
             value,
             cu_seqlens_q,
             cu_seqlens_k,
+            fixed_seed_offset,
             max_seqlen_q,
             max_seqlen_k,
             scale,
@@ -210,6 +359,7 @@
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
@@ -225,6 +375,7 @@
         'v': value,
         'cu_seqlens_q': cu_seqlens_q,
         'cu_seqlens_k': cu_seqlens_k,
+        'fixed_seed_offset': fixed_seed_offset,
     }
     outputs = {
         'out': out,
@@ -244,6 +395,10 @@
             'causal': causal,
             'return_softmax': return_softmax,
             'is_test': not training,
+            'rng_name': rng_name,
         },
     )
     return out, softmax if return_softmax else None
+
+
+scaled_dot_product_attention = flash_attention
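Putting the Python pieces together: a caller picks a backend with the `sdp_kernel` context manager and then calls the new `scaled_dot_product_attention` alias, much as `TestMathAttentionAPITest` above does. A sketch (it assumes a CUDA build of Paddle; the flash and memory-efficient paths need a GPU, and here the reference "math" path is forced instead):

```python
import paddle
from paddle.nn.functional import scaled_dot_product_attention, sdp_kernel

# [batch, seq_len, num_heads, head_dim]; fp16 matches what the flash kernel expects.
q = paddle.randn([2, 128, 8, 64]).astype('float16')
k = paddle.randn([2, 128, 8, 64]).astype('float16')
v = paddle.randn([2, 128, 8, 64]).astype('float16')

# Force the reference "math" implementation; flash and mem-efficient stay disabled.
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    out, _ = scaled_dot_product_attention(q, k, v, dropout=0.0, causal=False)

print(out.shape)  # [2, 128, 8, 64]
```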