Unverified commit 00ac8014, authored by Chitsing KUI, committed by GitHub

[FlashAttn] add flash randomness control (#52902)

* add flash randomness control

* fix VLOG undefined
Parent 67c6cfe0
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
 set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
 set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
 set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
+set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
 set(FLASHATTN_INCLUDE_DIR
     "${FLASHATTN_INSTALL_DIR}/include"
...
@@ -617,7 +617,7 @@
     inplace : (out_grad -> x_grad)

 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
@@ -628,7 +628,7 @@
     data_type: q

 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
...
@@ -678,8 +678,9 @@
   backward : fill_diagonal_tensor_grad

 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
@@ -690,8 +691,9 @@
   backward : flash_attn_grad

 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
...
@@ -20,33 +20,38 @@
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset);
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const DenseTensor& cu_seqlens_q,
+    const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
+    float dropout,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
+    const std::string& rng_name,
+    DenseTensor* out,
+    DenseTensor* softmax,
+    DenseTensor* softmax_lse,
+    DenseTensor* seed_offset);

 template <typename T, typename Context>
 void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
...
@@ -13,8 +13,10 @@
 // limitations under the License.

 #include "paddle/phi/kernels/flash_attn_grad_kernel.h"

+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
@@ -25,6 +27,8 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename T, typename Context>
@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;

   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;

   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
   uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
   uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);

+  VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;
+
   int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;

   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);

+  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
...
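The FLAGS_cudnn_deterministic check above forces num_splits to 1, trading the backward kernel's internal split heuristic for a reproducible gradient accumulation order. A minimal sketch of turning this on from Python, assuming a GPU build of Paddle with FlashAttention support (the flag can also be exported as an environment variable before launching the process):

import paddle

# Global Paddle flag; with it set, the flash_attn backward uses num_splits = 1
# (see the kernel change above) and accumulates gradients deterministically.
paddle.set_flags({'FLAGS_cudnn_deterministic': True})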
@@ -14,12 +14,13 @@
 #include "paddle/phi/kernels/flash_attn_kernel.h"

+#include "glog/logging.h"  // For VLOG()
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/reshape_kernel.h"
@@ -28,26 +29,31 @@
 #include "paddle/phi/backends/dynload/flashattn.h"
 #endif

+DECLARE_bool(cudnn_deterministic);
+
 namespace phi {

 template <typename T, typename Context>
-void FlashAttnUnpaddedKernel(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const DenseTensor& cu_seqlens_q,
-                             const DenseTensor& cu_seqlens_k,
-                             int64_t max_seqlen_q,
-                             int64_t max_seqlen_k,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             bool return_softmax,
-                             bool is_test,
-                             DenseTensor* out,
-                             DenseTensor* softmax,
-                             DenseTensor* softmax_lse,
-                             DenseTensor* seed_offset) {
+void FlashAttnUnpaddedKernel(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const DenseTensor& cu_seqlens_q,
+    const DenseTensor& cu_seqlens_k,
+    const paddle::optional<DenseTensor>& fixed_seed_offset,
+    int64_t max_seqlen_q,
+    int64_t max_seqlen_k,
+    float scale,
+    float dropout,
+    bool causal,
+    bool return_softmax,
+    bool is_test,
+    const std::string& rng_name,
+    DenseTensor* out,
+    DenseTensor* softmax,
+    DenseTensor* softmax_lse,
+    DenseTensor* seed_offset) {
 #ifdef PADDLE_WITH_FLASHATTN
   if (is_test) dropout = 0.0f;
@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
   int64_t batch_size = cu_seqlens_q.numel() - 1;

   int num_splits = 0;  // 0 for an internal heuristic, which is optimal
+  if (FLAGS_cudnn_deterministic) {
+    num_splits = 1;
+  }
   bool zero_tensors = false;

-  auto gen = ctx.GetGenerator();
-  uint64_t inc = batch_size * num_heads * 32;
-  auto seed_offset_pair = gen->IncrementOffset(inc);
+  uint64_t seed;
+  uint64_t offset;
+
+  if (fixed_seed_offset.get_ptr()) {
+    const int64_t* fixed_seed_offset_data =
+        fixed_seed_offset.get_ptr()->data<int64_t>();
+    seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
+    offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
+  } else {
+    uint64_t inc = batch_size * num_heads * 32;
+    std::pair<uint64_t, uint64_t> seed_offset_pair;
+    if (rng_name != "") {
+      auto gen = phi::GetRandomSeedGenerator(rng_name);
+      seed_offset_pair = gen->IncrementOffset(inc);
+    } else {
+      auto* gen = ctx.GetGenerator();
+      seed_offset_pair = gen->IncrementOffset(inc);
+    }
+    seed = seed_offset_pair.first;
+    offset = seed_offset_pair.second;
+  }

-  uint64_t seed = seed_offset_pair.first;
-  uint64_t offset = seed_offset_pair.second;
+  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
+          << ", num_splits:" << num_splits;

   seed_offset->Resize({2});
-  auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
+  int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   seed_offset_data[0] = static_cast<int64_t>(seed);
   seed_offset_data[1] = static_cast<int64_t>(offset);
@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& q,
                      const DenseTensor& k,
                      const DenseTensor& v,
+                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                      float dropout,
                      bool causal,
                      bool return_softmax,
                      bool is_test,
+                     const std::string& rng_name,
                      DenseTensor* out,
                      DenseTensor* softmax,
                      DenseTensor* softmax_lse,
@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
   float scale = 1.0f / std::sqrt(head_size);

+  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
+          << "], v[" << v.dims() << "]";
+
   DenseTensor q_t_s, k_t_s, v_t_s;
   q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
   k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
@@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx,
       v_t_s,
       cu_seqlens_q,
       cu_seqlens_k,
+      fixed_seed_offset,
       seq_len_q,
       seq_len_k,
       scale,
@@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx,
       causal,
       return_softmax,
       is_test,
+      rng_name,
       out,
       softmax,
       softmax_lse,
@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
                    ALL_LAYOUT,
                    phi::FlashAttnUnpaddedKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(5).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}

 PD_REGISTER_KERNEL(flash_attn,
                    GPU,
                    ALL_LAYOUT,
                    phi::FlashAttnKernel,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(3).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}
...
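The forward kernel now resolves the dropout seed and offset with a three-way priority: an explicitly passed fixed_seed_offset tensor wins, otherwise a generator selected by rng_name is used, and only then the device context's default generator is advanced. A small Python sketch of the same priority; the helper names (resolve_seed_offset, named_gens, increment_offset) are hypothetical stand-ins for the C++ calls above, and only the ordering mirrors the kernel:

from typing import Optional, Sequence, Tuple

def resolve_seed_offset(
    fixed_seed_offset: Optional[Sequence[int]],  # 2-element [seed, offset] or None
    rng_name: str,                               # "" means "use the default generator"
    named_gens: dict,                            # stand-in for GetRandomSeedGenerator(rng_name)
    default_gen,                                 # stand-in for ctx.GetGenerator()
    increment: int,                              # batch_size * num_heads * 32 in the kernel
) -> Tuple[int, int]:
    # 1) An explicit fixed_seed_offset pins the dropout randomness completely.
    if fixed_seed_offset is not None:
        return int(fixed_seed_offset[0]), int(fixed_seed_offset[1])
    # 2) A non-empty rng_name selects a named generator (used by the auto-parallel
    #    random control so that ranks sharing a name draw identical masks);
    # 3) otherwise the device context's default generator is advanced.
    gen = named_gens[rng_name] if rng_name != "" else default_gen
    return gen.increment_offset(increment)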
@@ -38,3 +38,4 @@ from . import dist_shape
 from . import dist_assign
 from . import dist_scale
 from . import dist_dropout
+from . import dist_flash_attn
New file: dist_flash_attn.py

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from ...utils.log_utils import get_logger

_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
    DistributedOperatorImplContainer,
    register_distributed_operator_impl,
    register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0


class DistributedFlashAttn(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super().__init__(op_type)


register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))


# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
    def __init__(self, name):
        super().__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        return True

    @staticmethod
    def forward(ctx, *args, **kwargs):
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)

        if (
            is_enable_auto_rand_ctrl()
            and not op_dist_attr.is_recompute
            and rank_id in op_dist_attr.process_mesh.process_ids
        ):
            assert (
                op_dist_attr is not None
            ), f"forward op [{str(src_op)}] don't have dist attribute !"

            if (
                len(kwargs.get('fixed_seed_offset', [])) > 0
                or len(src_op.input("fixed_seed_offset")) > 0
            ):
                # TODO(kuizhiqing) recompute should go here
                pass
            else:
                # determinate rng
                q_var = main_block._var_recursive(kwargs['q'][0])
                k_var = main_block._var_recursive(kwargs['k'][0])
                q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
                k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)

                process_mesh = op_dist_attr.process_mesh
                dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]

                rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
                assert rng_name is not None and rng_name != ""

                src_op._set_attr('rng_name', rng_name)

        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        # dropout backward is deterministic by mask, and not need for random state control
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "flash_attn", DistributedFlashAttnImpl0("random_control")
)
@@ -24,6 +24,9 @@ def flash_attention(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    *,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -57,7 +60,9 @@ def flash_attention(
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
         training(bool): Whether it is in the training phase.
+        rng_name(str): The name to select Generator.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
@@ -84,10 +89,12 @@ def flash_attention(
             query,
             key,
             value,
+            fixed_seed_offset,
             dropout,
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
@@ -101,6 +108,7 @@ def flash_attention(
        'q': query,
        'k': key,
        'v': value,
+       'fixed_seed_offset': fixed_seed_offset,
    }
    outputs = {
        'out': out,
@@ -117,6 +125,7 @@ def flash_attention(
            'causal': causal,
            'return_softmax': return_softmax,
            'is_test': not training,
+           'rng_name': rng_name,
        },
    )
    return out, softmax if return_softmax else None
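With the new keyword-only arguments, a caller can pin the dropout randomness by passing a 2-element int64 (seed, offset) tensor, or route it through a named generator via rng_name. A minimal dygraph sketch, assuming a GPU build of Paddle with FlashAttention support and float16 inputs; the shapes are illustrative:

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# [batch, seq_len, num_heads, head_dim] in float16 on the GPU.
q = paddle.randn([2, 128, 8, 64]).astype('float16')

# A 2-element int64 tensor holding (seed, offset), matching how the kernel
# reads fixed_seed_offset[0] and fixed_seed_offset[1].
fixed = paddle.to_tensor([2023, 0], dtype='int64')

out1, _ = flash_attention(q, q, q, dropout=0.1, causal=True,
                          fixed_seed_offset=fixed)
out2, _ = flash_attention(q, q, q, dropout=0.1, causal=True,
                          fixed_seed_offset=fixed)
# With the same (seed, offset) and the same inputs, the two calls should use
# the same dropout mask, so out1 and out2 should match.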
@@ -134,6 +143,8 @@ def flash_attn_unpadded(
     dropout=0.0,
     causal=False,
     return_softmax=False,
+    fixed_seed_offset=None,
+    rng_name="",
     training=True,
     name=None,
 ):
@@ -174,6 +185,8 @@ def flash_attn_unpadded(
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
+        rng_name(str): The name to select Generator.
         training(bool): Whether it is in the training phase.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
@@ -203,6 +216,7 @@ def flash_attn_unpadded(
             value,
             cu_seqlens_q,
             cu_seqlens_k,
+            fixed_seed_offset,
             max_seqlen_q,
             max_seqlen_k,
             scale,
@@ -210,6 +224,7 @@ def flash_attn_unpadded(
             causal,
             return_softmax,
             not training,
+            rng_name,
         )
         return result_attention, result_softmax if return_softmax else None
@@ -225,6 +240,7 @@ def flash_attn_unpadded(
        'v': value,
        'cu_seqlens_q': cu_seqlens_q,
        'cu_seqlens_k': cu_seqlens_k,
+       'fixed_seed_offset': fixed_seed_offset,
    }
    outputs = {
        'out': out,
@@ -244,6 +260,7 @@ def flash_attn_unpadded(
            'causal': causal,
            'return_softmax': return_softmax,
            'is_test': not training,
+           'rng_name': rng_name,
        },
    )
    return out, softmax if return_softmax else None
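The varlen entry point gains the same two knobs. Below is a sketch of a packed (unpadded) call with the randomness pinned; it is not taken from the patch: the keyword names follow the op definition above, cu_seqlens is assumed to be an int32 tensor of cumulative sequence offsets per the usual FlashAttention convention, and the shapes are illustrative:

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

# Two sequences of lengths 3 and 5 packed into one [total_tokens, num_heads, head_dim] tensor.
total_tokens, num_heads, head_dim = 8, 4, 64
q = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
k = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
v = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
cu_seqlens = paddle.to_tensor([0, 3, 8], dtype='int32')  # cumulative sequence lengths

out, _ = flash_attn_unpadded(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,
    max_seqlen_k=5,
    scale=1.0 / head_dim ** 0.5,
    dropout=0.1,
    fixed_seed_offset=paddle.to_tensor([2023, 0], dtype='int64'),
    rng_name="",  # or the name of a registered seed generator
)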