Unverified · Commit 1d23e0bb authored by zhangkaihuo, committed by GitHub

[cherry-pick] add flash randomness control and add scaled_dot_product_attention (#53518)

att, cherry-pick: #52902 #53113
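
For context, a minimal sketch of the randomness-control surface this cherry-pick adds, assuming a CUDA build with the flash-attention third-party library; the shapes, seed/offset values, and variable names below are illustrative only:

```python
import paddle
from paddle.nn.functional.flash_attention import flash_attention

# Illustrative shapes: [batch, seq_len, num_heads, head_dim]; flash_attn needs fp16/bf16.
q = paddle.randn([2, 1024, 8, 64]).astype('float16')

# fixed_seed_offset pins the dropout RNG state as an int64 tensor [seed, offset],
# so repeated runs produce the same dropout mask.
seed_offset = paddle.to_tensor([2023, 0], dtype='int64')
out, _ = flash_attention(q, q, q, dropout=0.1, causal=True,
                         fixed_seed_offset=seed_offset)
```

When fixed_seed_offset is not given, the new rng_name argument selects a named generator instead; the dist_flash_attn operator below sets it automatically so dropout masks stay consistent across auto-parallel ranks.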
Parent 39b704c1
......@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
......
......@@ -617,7 +617,7 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......@@ -628,7 +628,7 @@
data_type: q
- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......
......@@ -678,8 +678,9 @@
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......@@ -690,8 +691,9 @@
backward : flash_attn_grad
- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......
......@@ -20,33 +20,38 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_grad_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
......@@ -25,6 +27,8 @@
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
......@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
......@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
......
......@@ -14,12 +14,13 @@
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
......@@ -28,26 +29,31 @@
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
......@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed;
uint64_t offset;
if (fixed_seed_offset.get_ptr()) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset.get_ptr()->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
......@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
......@@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
seq_len_q,
seq_len_k,
scale,
......@@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx,
causal,
return_softmax,
is_test,
rng_name,
out,
softmax,
softmax_lse,
......@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
ALL_LAYOUT,
phi::FlashAttnUnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
PD_REGISTER_KERNEL(flash_attn,
GPU,
ALL_LAYOUT,
phi::FlashAttnKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(3).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
......@@ -38,3 +38,4 @@ from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
from . import dist_flash_attn
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedFlashAttn(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))
# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
return True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if (
is_enable_auto_rand_ctrl()
and not op_dist_attr.is_recompute
and rank_id in op_dist_attr.process_mesh.process_ids
):
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
if (
len(kwargs.get('fixed_seed_offset', [])) > 0
or len(src_op.input("fixed_seed_offset")) > 0
):
# TODO(kuizhiqing) recompute should go here
pass
else:
# determinate rng
q_var = main_block._var_recursive(kwargs['q'][0])
k_var = main_block._var_recursive(kwargs['k'][0])
q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
process_mesh = op_dist_attr.process_mesh
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
assert rng_name is not None and rng_name != ""
src_op._set_attr('rng_name', rng_name)
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is made deterministic by the saved mask, so no random state control is needed
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"flash_attn", DistributedFlashAttnImpl0("random_control")
)
......@@ -68,6 +68,7 @@ class TestFlashAttentionAPI(unittest.TestCase):
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
def test_unpadded(self):
print(
......@@ -189,9 +190,19 @@ class TestFlashAttentionAPI(unittest.TestCase):
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
if self.use_sdp_kernel:
with paddle.nn.functional.sdp_kernel(
enable_math=self.enable_math,
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
else:
out, _ = flash_attention(
q, k, v, self.dropout, self.causal, self.return_softmax
)
out_ = attention_naive(q_, k_, v_, self.causal)
out.backward()
......@@ -220,9 +231,24 @@ class TestFlashAttentionAPI(unittest.TestCase):
name="v", shape=self.shape, dtype=self.dtype
)
outs, softmax = flash_attention(
qs, ks, vs, self.dropout, self.causal, self.return_softmax
)
if self.use_sdp_kernel:
with paddle.nn.functional.sdp_kernel(
enable_math=self.enable_math,
enable_flash=self.enable_flash,
enable_mem_efficient=self.enable_mem_efficient,
):
outs, softmax = flash_attention(
qs,
ks,
vs,
self.dropout,
self.causal,
self.return_softmax,
)
else:
outs, softmax = flash_attention(
qs, ks, vs, self.dropout, self.causal, self.return_softmax
)
exe = fluid.Executor(self.place)
fetches_result = exe.run(
......@@ -247,6 +273,7 @@ class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
......@@ -257,6 +284,7 @@ class TestFlashAttentionAPITest2(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = False
self.return_softmax = True
self.use_sdp_kernel = False
class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
......@@ -267,6 +295,7 @@ class TestFlashAttentionAPITest3(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = True
self.return_softmax = False
self.use_sdp_kernel = False
class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
......@@ -277,6 +306,21 @@ class TestFlashAttentionAPITest4(TestFlashAttentionAPI):
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = False
class TestMathAttentionAPITest(TestFlashAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (8, 1024, 16, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.causal = False
self.return_softmax = False
self.use_sdp_kernel = True
self.enable_math = True
self.enable_flash = False
self.enable_mem_efficient = False
if __name__ == '__main__':
......
......@@ -134,6 +134,8 @@ from .extension import gather_tree # noqa: F401
from .extension import temporal_shift # noqa: F401
from .sparse_attention import sparse_attention
from .flash_attention import scaled_dot_product_attention
from .flash_attention import sdp_kernel
__all__ = [ # noqa
'celu',
......
......@@ -13,8 +13,113 @@
# limitations under the License.
import paddle
import paddle.nn.functional as F
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
g_enable_math = None
g_enable_flash = None
g_enable_mem_efficient = None
@signature_safe_contextmanager
def sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=True):
r"""
With the sdp_kernel context manager, different algorithm implementations can
be selected for scaled_dot_product_attention.
"""
global g_enable_math, g_enable_flash, g_enable_mem_efficient
original_enable_math = g_enable_math
original_enable_flash = g_enable_flash
original_enable_mem_efficient = g_enable_mem_efficient
g_enable_math = enable_math
g_enable_flash = enable_flash
g_enable_mem_efficient = enable_mem_efficient
try:
yield
finally:
g_enable_math = original_enable_math
g_enable_flash = original_enable_flash
g_enable_mem_efficient = original_enable_mem_efficient
def _math_attention(
query,
key,
value,
dropout_rate=0.0,
causal=False,
return_softmax=False,
training=True,
):
r"""
This is a basic implementation of scaled dot product attention built from
fundamental operations; it backs the "math" backend option.
"""
head_dim = query.shape[-1]
query = paddle.transpose(query, [0, 2, 1, 3])
key = paddle.transpose(key, [0, 2, 1, 3])
value = paddle.transpose(value, [0, 2, 1, 3])
product = paddle.matmul(
x=query * (head_dim**-0.5), y=key, transpose_y=True
)
weights = (
paddle.incubate.softmax_mask_fuse_upper_triangle(product)
if causal
else F.softmax(product)
)
if dropout_rate > 0.0:
weights = F.dropout(
weights, dropout_rate, training=training, mode="upscale_in_train"
)
out = paddle.matmul(weights, value)
out = paddle.transpose(out, [0, 2, 1, 3])
return out, weights if return_softmax else None
def _select_sdp_cuda(head_dim):
if head_dim < 128:
return "flash_attn"
else:
return "mem_efficient"
def _select_sdp(head_dim):
r"""
Three implementations of scaled dot product attention are currently available;
the choice is driven by the sdp_kernel configuration when it is set, and is
otherwise inferred from the inputs (device and head dimension).
"""
place = paddle.get_device()
# not use sdp_kernel
if g_enable_flash is None:
if "gpu" not in place:
return "math"
else:
return _select_sdp_cuda(head_dim)
if (
g_enable_math is False
and g_enable_flash is False
and g_enable_mem_efficient is False
):
raise AssertionError(
"No available backend for scaled_dot_product_attention was found."
)
if g_enable_math is True:
if g_enable_flash is False and g_enable_mem_efficient is False:
return "math"
if "gpu" not in place:
return "math"
if g_enable_flash is True and g_enable_mem_efficient is True:
return _select_sdp_cuda(head_dim)
if g_enable_flash is True:
return "flash_attn"
return "mem_efficient"
def flash_attention(
......@@ -24,6 +129,9 @@ def flash_attention(
dropout=0.0,
causal=False,
return_softmax=False,
*,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -57,7 +165,9 @@ def flash_attention(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): Fixed seed and offset used to generate the dropout mask.
rng_name(str): The name used to select the random generator.
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
......@@ -79,47 +189,81 @@ def flash_attention(
output = paddle.nn.functional.flash_attention(q, q, q, 0.9, False, False)
print(output)
"""
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn(
query,
key,
value,
dropout,
causal,
return_softmax,
not training,
head_dim = query.shape[3]
sdp_func_name = _select_sdp(head_dim)
if sdp_func_name == "flash_attn":
if in_dynamic_mode():
(result_attention, result_softmax,) = _C_ops.flash_attn(
query,
key,
value,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
helper = LayerHelper('flash_attn', **locals())
dtype = helper.input_dtype(input_param_name='q')
out = helper.create_variable_for_type_inference(dtype)
softmax = helper.create_variable_for_type_inference(dtype)
softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
seed_offset = helper.create_variable_for_type_inference(paddle.int64)
inputs = {
'q': query,
'k': key,
'v': value,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
'softmax': softmax,
'softmax_lse': softmax_lse,
'seed_offset': seed_offset,
}
helper.append_op(
type='flash_attn',
inputs=inputs,
outputs=outputs,
attrs={
'dropout': dropout,
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return result_attention, result_softmax if return_softmax else None
return out, softmax if return_softmax else None
else:
if sdp_func_name == "mem_efficient":
from paddle.incubate.nn.memory_efficient_attention import (
memory_efficient_attention,
)
helper = LayerHelper('flash_attn', **locals())
dtype = helper.input_dtype(input_param_name='q')
out = helper.create_variable_for_type_inference(dtype)
softmax = helper.create_variable_for_type_inference(dtype)
softmax_lse = helper.create_variable_for_type_inference(paddle.float32)
seed_offset = helper.create_variable_for_type_inference(paddle.int64)
inputs = {
'q': query,
'k': key,
'v': value,
}
outputs = {
'out': out,
'softmax': softmax,
'softmax_lse': softmax_lse,
'seed_offset': seed_offset,
}
helper.append_op(
type='flash_attn',
inputs=inputs,
outputs=outputs,
attrs={
'dropout': dropout,
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
},
)
return out, softmax if return_softmax else None
output = memory_efficient_attention(
query,
key,
value,
attn_bias=None,
p=dropout,
scale=None,
training=training,
)
return output, None
else:
return _math_attention(
query,
key,
value,
dropout_rate=dropout,
causal=causal,
return_softmax=return_softmax,
training=training,
)
def flash_attn_unpadded(
......@@ -134,6 +278,8 @@ def flash_attn_unpadded(
dropout=0.0,
causal=False,
return_softmax=False,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -174,6 +320,8 @@ def flash_attn_unpadded(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): Fixed seed and offset used to generate the dropout mask.
rng_name(str): The name used to select the random generator.
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
......@@ -203,6 +351,7 @@ def flash_attn_unpadded(
value,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
max_seqlen_q,
max_seqlen_k,
scale,
......@@ -210,6 +359,7 @@ def flash_attn_unpadded(
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
......@@ -225,6 +375,7 @@ def flash_attn_unpadded(
'v': value,
'cu_seqlens_q': cu_seqlens_q,
'cu_seqlens_k': cu_seqlens_k,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -244,6 +395,10 @@ def flash_attn_unpadded(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
scaled_dot_product_attention = flash_attention
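
Since scaled_dot_product_attention is an alias of flash_attention, backend choice is controlled entirely by the sdp_kernel context manager and the inputs. A minimal sketch forcing the composed "math" implementation (mirroring the new TestMathAttentionAPITest), assuming this branch is installed; shapes are illustrative:

```python
import paddle
from paddle.nn.functional import scaled_dot_product_attention, sdp_kernel

# Illustrative shapes: [batch, seq_len, num_heads, head_dim].
q = paddle.randn([2, 128, 8, 64])

# Disable the flash and memory-efficient kernels so the pure-paddle math path runs;
# this is a handy numerical reference for the fused kernels.
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    out, _ = scaled_dot_product_attention(q, q, q, dropout=0.0, causal=False)

print(out.shape)  # [2, 128, 8, 64]
```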