未验证 提交 00ac8014 编写于 作者: C Chitsing KUI 提交者: GitHub

[FlashAttn] add flash randomness control (#52902)

* add flash randomness control

* fix VLOG undefied
上级 67c6cfe0
......@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)
set(FLASHATTN_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)
set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
......
......@@ -617,7 +617,7 @@
inplace : (out_grad -> x_grad)
- backward_op : flash_attn_grad
forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......@@ -628,7 +628,7 @@
data_type: q
- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
......
......@@ -678,8 +678,9 @@
backward : fill_diagonal_tensor_grad
- op : flash_attn
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......@@ -690,8 +691,9 @@
backward : flash_attn_grad
- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset
infer_meta :
func : FlashAttnInferMeta
param : [q, k, v]
......
......@@ -20,33 +20,38 @@
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset);
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......
......@@ -13,8 +13,10 @@
// limitations under the License.
#include "paddle/phi/kernels/flash_attn_grad_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
......@@ -25,6 +27,8 @@
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
......@@ -65,12 +69,18 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seq_len_q});
......@@ -187,6 +197,9 @@ void FlashAttnGradKernel(const Context& ctx,
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
......
......@@ -14,12 +14,13 @@
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "glog/logging.h" // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
......@@ -28,26 +29,31 @@
#include "paddle/phi/backends/dynload/flashattn.h"
#endif
DECLARE_bool(cudnn_deterministic);
namespace phi {
template <typename T, typename Context>
void FlashAttnUnpaddedKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
void FlashAttnUnpaddedKernel(
const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const DenseTensor& cu_seqlens_q,
const DenseTensor& cu_seqlens_k,
const paddle::optional<DenseTensor>& fixed_seed_offset,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
float scale,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
if (is_test) dropout = 0.0f;
......@@ -73,17 +79,38 @@ void FlashAttnUnpaddedKernel(const Context& ctx,
int64_t batch_size = cu_seqlens_q.numel() - 1;
int num_splits = 0; // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
num_splits = 1;
}
bool zero_tensors = false;
auto gen = ctx.GetGenerator();
uint64_t inc = batch_size * num_heads * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
uint64_t seed;
uint64_t offset;
if (fixed_seed_offset.get_ptr()) {
const int64_t* fixed_seed_offset_data =
fixed_seed_offset.get_ptr()->data<int64_t>();
seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
} else {
uint64_t inc = batch_size * num_heads * 32;
std::pair<uint64_t, uint64_t> seed_offset_pair;
if (rng_name != "") {
auto gen = phi::GetRandomSeedGenerator(rng_name);
seed_offset_pair = gen->IncrementOffset(inc);
} else {
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
seed = seed_offset_pair.first;
offset = seed_offset_pair.second;
}
uint64_t seed = seed_offset_pair.first;
uint64_t offset = seed_offset_pair.second;
VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
seed_offset->Resize({2});
auto* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
seed_offset_data[0] = static_cast<int64_t>(seed);
seed_offset_data[1] = static_cast<int64_t>(offset);
......@@ -187,10 +214,12 @@ void FlashAttnKernel(const Context& ctx,
const DenseTensor& q,
const DenseTensor& k,
const DenseTensor& v,
const paddle::optional<DenseTensor>& fixed_seed_offset,
float dropout,
bool causal,
bool return_softmax,
bool is_test,
const std::string& rng_name,
DenseTensor* out,
DenseTensor* softmax,
DenseTensor* softmax_lse,
......@@ -217,6 +246,9 @@ void FlashAttnKernel(const Context& ctx,
float scale = 1.0f / std::sqrt(head_size);
VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
DenseTensor q_t_s, k_t_s, v_t_s;
q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
......@@ -235,6 +267,7 @@ void FlashAttnKernel(const Context& ctx,
v_t_s,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
seq_len_q,
seq_len_k,
scale,
......@@ -242,6 +275,7 @@ void FlashAttnKernel(const Context& ctx,
causal,
return_softmax,
is_test,
rng_name,
out,
softmax,
softmax_lse,
......@@ -257,11 +291,17 @@ PD_REGISTER_KERNEL(flash_attn_unpadded,
ALL_LAYOUT,
phi::FlashAttnUnpaddedKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(5).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
PD_REGISTER_KERNEL(flash_attn,
GPU,
ALL_LAYOUT,
phi::FlashAttnKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->InputAt(3).SetBackend(
phi::Backend::ALL_BACKEND); // fixed_seed_offset
}
......@@ -38,3 +38,4 @@ from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
from . import dist_flash_attn
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedFlashAttn(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedFlashAttn("flash_attn"))
# Dist FlashAttn with Random Control
class DistributedFlashAttnImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
return True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if (
is_enable_auto_rand_ctrl()
and not op_dist_attr.is_recompute
and rank_id in op_dist_attr.process_mesh.process_ids
):
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
if (
len(kwargs.get('fixed_seed_offset', [])) > 0
or len(src_op.input("fixed_seed_offset")) > 0
):
# TODO(kuizhiqing) recompute should go here
pass
else:
# determinate rng
q_var = main_block._var_recursive(kwargs['q'][0])
k_var = main_block._var_recursive(kwargs['k'][0])
q_dims_mapping = op_dist_attr.get_input_dims_mapping(q_var.name)
k_dims_mapping = op_dist_attr.get_input_dims_mapping(k_var.name)
process_mesh = op_dist_attr.process_mesh
dims_mapping = q_dims_mapping[:3] + [q_dims_mapping[2]]
rng_name = determinate_rng(rank_id, dims_mapping, process_mesh)
assert rng_name is not None and rng_name != ""
src_op._set_attr('rng_name', rng_name)
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is deterministic by mask, and not need for random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"flash_attn", DistributedFlashAttnImpl0("random_control")
)
......@@ -24,6 +24,9 @@ def flash_attention(
dropout=0.0,
causal=False,
return_softmax=False,
*,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -57,7 +60,9 @@ def flash_attention(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
training(bool): Whether it is in the training phase.
rng_name(str): The name to select Generator.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
......@@ -84,10 +89,12 @@ def flash_attention(
query,
key,
value,
fixed_seed_offset,
dropout,
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
......@@ -101,6 +108,7 @@ def flash_attention(
'q': query,
'k': key,
'v': value,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -117,6 +125,7 @@ def flash_attention(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
......@@ -134,6 +143,8 @@ def flash_attn_unpadded(
dropout=0.0,
causal=False,
return_softmax=False,
fixed_seed_offset=None,
rng_name="",
training=True,
name=None,
):
......@@ -174,6 +185,8 @@ def flash_attn_unpadded(
dropout(float): The dropout ratio.
causal(bool): Whether enable causal mode.
return_softmax(bool): Whether to return softmax.
fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask.
rng_name(str): The name to select Generator.
training(bool): Whether it is in the training phase.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
......@@ -203,6 +216,7 @@ def flash_attn_unpadded(
value,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
max_seqlen_q,
max_seqlen_k,
scale,
......@@ -210,6 +224,7 @@ def flash_attn_unpadded(
causal,
return_softmax,
not training,
rng_name,
)
return result_attention, result_softmax if return_softmax else None
......@@ -225,6 +240,7 @@ def flash_attn_unpadded(
'v': value,
'cu_seqlens_q': cu_seqlens_q,
'cu_seqlens_k': cu_seqlens_k,
'fixed_seed_offset': fixed_seed_offset,
}
outputs = {
'out': out,
......@@ -244,6 +260,7 @@ def flash_attn_unpadded(
'causal': causal,
'return_softmax': return_softmax,
'is_test': not training,
'rng_name': rng_name,
},
)
return out, softmax if return_softmax else None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册