flash_attn_kernel.cu 8.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/flash_attn_kernel.h"

17
#include "glog/logging.h"  // For VLOG()
18 19
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
C
Chitsing KUI 已提交
20
#include "paddle/phi/core/enforce.h"
21
#include "paddle/phi/core/flags.h"
22 23 24 25 26 27 28 29
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
30
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
31 32
#endif

33 34
DECLARE_bool(cudnn_deterministic);

35 36 37
namespace phi {

template <typename T, typename Context>
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
void FlashAttnUnpaddedKernel(
    const Context& ctx,
    const DenseTensor& q,
    const DenseTensor& k,
    const DenseTensor& v,
    const DenseTensor& cu_seqlens_q,
    const DenseTensor& cu_seqlens_k,
    const paddle::optional<DenseTensor>& fixed_seed_offset,
    int64_t max_seqlen_q,
    int64_t max_seqlen_k,
    float scale,
    float dropout,
    bool causal,
    bool return_softmax,
    bool is_test,
    const std::string& rng_name,
    DenseTensor* out,
    DenseTensor* softmax,
    DenseTensor* softmax_lse,
    DenseTensor* seed_offset) {
58
#ifdef PADDLE_WITH_FLASHATTN
S
sneaxiy 已提交
59

60 61 62 63
  ctx.template Alloc<T>(out);

  cudaStream_t stream = ctx.stream();

C
Chitsing KUI 已提交
64
  // q,k,v [total_*, num_heads, head_dim]
65 66

  auto dims = q.dims();
C
Chitsing KUI 已提交
67 68 69 70 71
  PADDLE_ENFORCE_EQ(
      dims.size(),
      3,
      phi::errors::InvalidArgument("flash_attn_raw receive input with dim "
                                   "[total_seq_len, num_heads, head_dim]"));
72

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
  const int64_t total_q = dims[0];
  const int num_heads = dims[1];
  const int head_size = dims[2];

  const int total_k = k.dims()[0];
  const int num_heads_k = k.dims()[1];
  const int batch_size = cu_seqlens_q.numel() - 1;

  // TODO(umiswing): add deterministic in fa2.
  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
  // if (FLAGS_cudnn_deterministic) {
  //   num_splits = 1;
  // }

  // TODO(umiswing): add shape check

  FlashAttnFwdParamsV2<T> params =
      FlashAttnFwdParamsV2<T>(ctx,
                              batch_size,
                              max_seqlen_q,
                              max_seqlen_k,
                              num_heads,
                              num_heads_k,
                              head_size,
                              dropout,
                              scale,
                              causal,
                              return_softmax,
                              q.dtype(),
                              is_test,
                              rng_name,
                              fixed_seed_offset.get_ptr(),
                              softmax,
                              softmax_lse,
                              seed_offset);

  VLOG(4) << "FlashAttn fwd seed: " << params.seed
          << ", offset: " << params.offset;

  const bool succ = phi::dynload::flash_attn_varlen_fwd(
C
Chitsing KUI 已提交
113 114 115
      q.data(),
      k.data(),
      v.data(),
116 117 118
      cu_seqlens_q.data<int32_t>(),
      cu_seqlens_k.data<int32_t>(),
      params.rng_state.data(),
119
      out->data(),
120
      params.return_softmax ? softmax->data() : nullptr,
121
      softmax_lse->data(),
122 123 124 125 126 127 128 129 130 131 132 133 134 135
      params.batch_size,
      params.max_seqlen_q,
      params.max_seqlen_k,
      params.seqlen_q_rounded,
      params.seqlen_k_rounded,
      params.num_heads,
      params.num_heads_k,
      params.head_size,
      params.head_size_rounded,
      params.dropout,
      params.scale,
      params.causal,
      params.return_softmax,
      params.is_bf16,
136
      stream,
137 138
      params.seed,
      params.offset);
139 140 141 142

  if (!succ) {
    PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
  }
143 144 145
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "FlashAttention is unsupported, please set use_flash_attn to false."));
146 147 148
#endif
}

C
Chitsing KUI 已提交
149 150 151 152 153
template <typename T, typename Context>
void FlashAttnKernel(const Context& ctx,
                     const DenseTensor& q,
                     const DenseTensor& k,
                     const DenseTensor& v,
154
                     const paddle::optional<DenseTensor>& fixed_seed_offset,
C
Chitsing KUI 已提交
155 156 157
                     float dropout,
                     bool causal,
                     bool return_softmax,
S
sneaxiy 已提交
158
                     bool is_test,
159
                     const std::string& rng_name,
C
Chitsing KUI 已提交
160 161
                     DenseTensor* out,
                     DenseTensor* softmax,
162
                     DenseTensor* softmax_lse,
C
Chitsing KUI 已提交
163 164 165 166 167 168 169 170 171 172 173
                     DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
  // q,k,v [batch_size, seq_len, num_heads, head_dim]

  auto dims = q.dims();
  PADDLE_ENFORCE_EQ(dims.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "flash_attn receive input with dim "
                        "[batch_size, seq_len, num_heads, head_dim]"));

174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
  const int batch_size = dims[0];
  const int seqlen_q = dims[1];
  const int num_heads = dims[2];
  const int head_size = dims[3];
  const int seqlen_k = k.dims()[1];
  const int num_heads_k = k.dims()[2];

  // TODO(umiswing): Add check shape

  const float scale = 1.0f / std::sqrt(head_size);

  FlashAttnFwdParamsV2<T> params =
      FlashAttnFwdParamsV2<T>(ctx,
                              batch_size,
                              seqlen_q,
                              seqlen_k,
                              num_heads,
                              num_heads_k,
                              head_size,
                              dropout,
                              scale,
                              causal,
                              return_softmax,
                              q.dtype(),
                              is_test,
                              rng_name,
                              fixed_seed_offset.get_ptr(),
                              softmax,
                              softmax_lse,
                              seed_offset);
C
Chitsing KUI 已提交
204

205 206
  VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims()
          << "], v[" << v.dims() << "]";
C
Chitsing KUI 已提交
207

208
  ctx.template Alloc<T>(out);
C
Chitsing KUI 已提交
209

210
  cudaStream_t stream = ctx.stream();
C
Chitsing KUI 已提交
211

212 213
  VLOG(4) << "FlashAttn fwd seed: " << params.seed
          << ", offset: " << params.offset;
214

215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
  bool succ = phi::dynload::flash_attn_fwd(
      q.data(),
      k.data(),
      v.data(),
      params.rng_state.data(),
      out->data(),
      params.return_softmax ? params.softmax->data() : nullptr,
      params.softmax_lse->data(),
      params.batch_size,
      params.max_seqlen_q,
      params.max_seqlen_k,
      params.seqlen_q_rounded,
      params.seqlen_k_rounded,
      params.num_heads,
      params.num_heads_k,
      params.head_size,
      params.head_size_rounded,
      params.dropout,
      params.scale,
      params.causal,
      params.return_softmax,
      params.is_bf16,
      stream,
      params.seed,
      params.offset);
C
Chitsing KUI 已提交
240

241 242 243 244 245 246 247 248
  PADDLE_ENFORCE_EQ(
      succ,
      true,
      phi::errors::External("Error in Flash-Attention-2, detail information is",
                            phi::dynload::flash_attn_error()));
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "FlashAttention is unsupported, please set use_flash_attn to false."));
C
Chitsing KUI 已提交
249 250 251
#endif
}

252 253
}  // namespace phi

254
// Registers the variable-length FlashAttention forward kernel for GPU in
// float16 and bfloat16.
PD_REGISTER_KERNEL(flash_attn_unpadded,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnUnpaddedKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  // Input order follows FlashAttnUnpaddedKernel: q(0), k(1), v(2),
  // cu_seqlens_q(3), cu_seqlens_k(4), fixed_seed_offset(5).
  kernel->InputAt(5).SetBackend(
      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
}

// Registers the padded FlashAttention forward kernel for GPU in float16 and
// bfloat16.
PD_REGISTER_KERNEL(flash_attn,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  // Input order follows FlashAttnKernel: q(0), k(1), v(2),
  // fixed_seed_offset(3).
  kernel->InputAt(3).SetBackend(
      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
}