// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/flash_attn_kernel.h"

#include "glog/logging.h"  // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#endif

DECLARE_bool(cudnn_deterministic);

namespace phi {

template <typename T, typename Context>
void FlashAttnUnpaddedKernel(
    const Context& ctx,
    const DenseTensor& q,
    const DenseTensor& k,
    const DenseTensor& v,
    const DenseTensor& cu_seqlens_q,
    const DenseTensor& cu_seqlens_k,
    const paddle::optional<DenseTensor>& fixed_seed_offset,
    int64_t max_seqlen_q,
    int64_t max_seqlen_k,
    float scale,
    float dropout,
    bool causal,
    bool return_softmax,
    bool is_test,
    const std::string& rng_name,
    DenseTensor* out,
    DenseTensor* softmax,
    DenseTensor* softmax_lse,
    DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
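  // Dropout is a no-op in inference mode, so force the rate to zero.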
  if (is_test) dropout = 0.0f;

  ctx.template Alloc<T>(out);

  cudaStream_t stream = ctx.stream();
  bool is_bf16 = q.dtype() == DataType::BFLOAT16;

  // q,k,v [total_*, num_heads, head_dim]

  auto dims = q.dims();
  PADDLE_ENFORCE_EQ(
      dims.size(),
      3,
      phi::errors::InvalidArgument("flash_attn_raw expects input with shape "
                                   "[total_seq_len, num_heads, head_dim]"));

  int64_t total_q = dims[0];
  int64_t num_heads = dims[1];
  int64_t head_size = dims[2];

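  // cu_seqlens_q stores batch_size + 1 cumulative sequence offsets,
  // so the batch size can be recovered as numel() - 1.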
  int64_t total_k = k.dims()[0];
  int64_t batch_size = cu_seqlens_q.numel() - 1;

  int num_splits = 0;  // 0 for an internal heuristic, which is optimal
  if (FLAGS_cudnn_deterministic) {
    num_splits = 1;
  }
  bool zero_tensors = false;

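  // Pick the dropout RNG state: reuse the (seed, offset) pair from
  // fixed_seed_offset when it is provided; otherwise draw a fresh pair from
  // the named generator (rng_name) or the device default generator,
  // advancing its offset by batch_size * num_heads * 32.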
  uint64_t seed;
  uint64_t offset;

  if (fixed_seed_offset.get_ptr()) {
    const int64_t* fixed_seed_offset_data =
        fixed_seed_offset.get_ptr()->data<int64_t>();
    seed = static_cast<uint64_t>(fixed_seed_offset_data[0]);
    offset = static_cast<uint64_t>(fixed_seed_offset_data[1]);
  } else {
    uint64_t inc = batch_size * num_heads * 32;
    std::pair<uint64_t, uint64_t> seed_offset_pair;
    if (rng_name != "") {
      auto gen = phi::GetRandomSeedGenerator(rng_name);
      seed_offset_pair = gen->IncrementOffset(inc);
    } else {
      auto* gen = ctx.GetGenerator();
      seed_offset_pair = gen->IncrementOffset(inc);
    }
    seed = seed_offset_pair.first;
    offset = seed_offset_pair.second;
  }

  VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset
          << ", num_splits: " << num_splits;

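  // Expose the seed/offset actually used as a host-side output so the same
  // dropout pattern can be reproduced later (e.g. by the backward kernel).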
  seed_offset->Resize({2});
  int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
  seed_offset_data[0] = static_cast<int64_t>(seed);
  seed_offset_data[1] = static_cast<int64_t>(offset);

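  // softmax_lse is padded along the query dimension: round max_seqlen_q up
  // to a multiple of 16.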
  int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16;

  softmax_lse->Resize({batch_size, num_heads, seq_len_q});
  ctx.template Alloc<float>(softmax_lse);

  if (return_softmax) {
    // may allocate more space than *max_seqlen_k*
    int64_t blocksize_c = head_size > 64 ? 128 : 256;
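    // Round max_seqlen_k up to a multiple of blocksize_c, with short
    // sequences clamped to 128 or 256.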
    int64_t seq_len_k =
        ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c;
    if (max_seqlen_k <= 128) {
      seq_len_k = 128;
    } else if (max_seqlen_k <= 256) {
      seq_len_k = 256;
    }
    softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k});
    ctx.template Alloc<T>(softmax);
  }

  uint64_t workspace_size;

  // TODO(kuizhiqing) pass allocation/empty func in capi to decouple
  // calculate workspace size before execution
  bool succ =
      phi::dynload::flash_attn_fwd(q.data(),
                                   k.data(),
                                   v.data(),
                                   nullptr,  // null out ptr: only query workspace size
                                   cu_seqlens_q.data(),
                                   cu_seqlens_k.data(),
                                   total_q,
                                   total_k,
                                   batch_size,
                                   num_heads,
                                   head_size,
                                   max_seqlen_q,
                                   max_seqlen_k,
                                   dropout,
                                   scale,
                                   zero_tensors,
                                   causal,
                                   is_bf16,
                                   num_splits,
                                   softmax_lse->data(),
                                   return_softmax ? softmax->data() : nullptr,
                                   nullptr,
                                   &workspace_size,
                                   stream,
                                   seed,
                                   offset);

  if (!succ) {
    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
  }

  DenseTensor workspace;
  if (workspace_size > 0) {
    workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
  }
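  // Second call: launch the forward kernel for real, reusing the workspace
  // allocated above when one was required.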

  succ = phi::dynload::flash_attn_fwd(
      q.data(),
      k.data(),
      v.data(),
      out->data(),
      cu_seqlens_q.data(),
      cu_seqlens_k.data(),
      total_q,
      total_k,
      batch_size,
      num_heads,
      head_size,
      max_seqlen_q,
      max_seqlen_k,
      dropout,
      scale,
      zero_tensors,
      causal,
      is_bf16,
      num_splits,
      softmax_lse->data(),
      return_softmax ? softmax->data() : nullptr,
      workspace_size > 0 ? workspace.data() : nullptr,
      &workspace_size,
      stream,
      seed,
      offset);

  if (!succ) {
    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
  }

#endif
}

template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
                     const DenseTensor& q,
                     const DenseTensor& k,
                     const DenseTensor& v,
                     const paddle::optional<DenseTensor>& fixed_seed_offset,
                     float dropout,
                     bool causal,
                     bool return_softmax,
                     bool is_test,
                     const std::string& rng_name,
                     DenseTensor* out,
                     DenseTensor* softmax,
                     DenseTensor* softmax_lse,
                     DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
  // q,k,v [batch_size, seq_len, num_heads, head_dim]
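  // This fixed-length variant flattens the padded batch and dispatches to
  // FlashAttnUnpaddedKernel below.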

  auto dims = q.dims();
  PADDLE_ENFORCE_EQ(dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "flash_attn expects input with shape "
                        "[batch_size, seq_len, num_heads, head_dim]"));

  int64_t batch_size = dims[0];
  int64_t seq_len_q = dims[1];
  int64_t num_heads = dims[2];
  int64_t head_size = dims[3];

  int64_t seq_len_k = k.dims()[1];

  int64_t total_q = batch_size * seq_len_q;
  int64_t total_k = batch_size * seq_len_k;

  float scale = 1.0f / std::sqrt(head_size);

  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
          << "], v[" << v.dims() << "]";

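  // View q/k/v as [total, num_heads, head_dim] without copying. Every
  // sequence in the padded batch has the same length, so the cumulative
  // sequence offsets built below are simple multiples of seq_len.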
  DenseTensor q_t_s, k_t_s, v_t_s;
  q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
  k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
  v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});

  DenseTensor cu_seqlens_q;
  DenseTensor cu_seqlens_k;
  ArangeNullaryKernel<int32_t, Context>(
      ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q);
  ArangeNullaryKernel<int32_t, Context>(
      ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k);

  FlashAttnUnpaddedKernel<T, Context>(ctx,
                                      q_t_s,
                                      k_t_s,
                                      v_t_s,
                                      cu_seqlens_q,
                                      cu_seqlens_k,
                                      fixed_seed_offset,
                                      seq_len_q,
                                      seq_len_k,
                                      scale,
                                      dropout,
                                      causal,
                                      return_softmax,
                                      is_test,
                                      rng_name,
                                      out,
                                      softmax,
                                      softmax_lse,
                                      seed_offset);

#endif
}

}  // namespace phi

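// Register GPU kernels for float16/bfloat16. The fixed_seed_offset input is
// declared as ALL_BACKEND so the seed tensor is not restricted to the GPU
// backend.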
PD_REGISTER_KERNEL(flash_attn_unpadded,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnUnpaddedKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(5).SetBackend(
      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
}

PD_REGISTER_KERNEL(flash_attn,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(3).SetBackend(
      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
}