// flash_attn_grad_kernel.cu
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/flash_attn_grad_kernel.h"

#include <cmath>
#include <vector>

#include "glog/logging.h"  // For VLOG()
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#include "paddle/phi/kernels/reshape_kernel.h"

DECLARE_bool(cudnn_deterministic);

namespace phi {

// Backward pass of variable-length ("unpadded") flash attention for the case
// where an attention mask is supplied. It calls
// phi::dynload::flash_attn_bwd_with_bias_and_mask twice:
//   1. with out == nullptr, to query the required workspace size;
//   2. with the real output and a workspace of that size, to compute the
//      gradients.
//
// Shapes as read here: q is [total_q, num_heads, head_size]; only dim 0 of k
// (total_k) is read. The batch size is cu_seqlens_q.numel() - 1.
// seed_offset holds {seed, offset} as int64 and is dereferenced on the host.
//
// The library is run on scale * q (scaled_q). The gradient it writes into dq
// is therefore taken w.r.t. scaled_q, and is multiplied by `scale` afterwards
// (chain rule) to obtain the gradient w.r.t. q.
//
// dq/dk/dv are not allocated here; the caller must allocate them.
template <typename T, typename Context>
void FlashAttnUnpaddedGradImpl(const Context& ctx,
                               const DenseTensor& q,
                               const DenseTensor& k,
                               const DenseTensor& v,
                               const DenseTensor& cu_seqlens_q,
                               const DenseTensor& cu_seqlens_k,
                               const DenseTensor& out,
                               const DenseTensor& softmax_lse,
                               const DenseTensor& seed_offset,
                               const paddle::optional<DenseTensor>& attn_mask,
                               const DenseTensor& dout,
                               int64_t max_seqlen_q,
                               int64_t max_seqlen_k,
                               float scale,
                               float dropout,
                               bool causal,
                               DenseTensor* dq,
                               DenseTensor* dk,
                               DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  const cudaStream_t stream = ctx.stream();

  auto dims = q.dims();
  const int64_t total_q = dims[0];
  const int64_t num_heads = dims[1];
  const int64_t head_size = dims[2];

  const int64_t total_k = k.dims()[0];
  const int64_t batch_size = cu_seqlens_q.numel() - 1;

  // Causal masking is rejected on this (masked) path.
  PADDLE_ENFORCE_NE(causal,
                    true,
                    phi::errors::InvalidArgument(
                        "attn_mask is not nullptr, causal can not be true"));

  PADDLE_ENFORCE_EQ(
      head_size == 32 || head_size == 64 || head_size == 128,
      true,
      phi::errors::InvalidArgument("The head_dim is expected to be either 32, "
                                   "64, or 128, but received %d.",
                                   head_size));

  // seed_offset = {seed, offset}; read directly through a host pointer.
  const int64_t* seed_offset_data = seed_offset.data<int64_t>();
  const uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
  const uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
  VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset;

  // Buffer passed as the library's dsoftmax argument. Its last dim is
  // max_seqlen_q rounded up to a multiple of 16.
  const int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
  DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seqlen_q});

  const DenseTensor* attn_mask_tensor = attn_mask.get_ptr();
  std::vector<int64_t> mask_dims = GetAttnMaskDims(attn_mask_tensor);
  const void* mask_ptr =
      attn_mask_tensor ? attn_mask_tensor->data() : nullptr;
  // data() of an empty vector is unspecified (it may be non-null), so map an
  // empty dims vector to nullptr explicitly.
  int64_t* mask_dims_ptr = mask_dims.empty() ? nullptr : mask_dims.data();

  int fa_num_splits = 0;  // 0: let the library choose
  bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16;
  float fa_with_mask_scale = 1.0f;
  bool fa_zero_tensors = false;

  // Fold the softmax scale into q. ComputeScaleQ is assumed to write
  // scale * q into scaled_q.
  const int64_t q_size = total_q * num_heads * head_size;
  DenseTensor scaled_q = Empty<T>(ctx, {total_q, num_heads, head_size});
  ComputeScaleQ(ctx, q_size, scale, q.data<T>(), scaled_q.data<T>());

  // Written by the first (size-query) call below.
  uint64_t workspace_size = 0;

  // Both library invocations share every argument except `out` and the
  // workspace pointer. Keep them in one place so they cannot drift apart.
  auto run_bwd = [&](const void* out_ptr, void* workspace_ptr) {
    return phi::dynload::flash_attn_bwd_with_bias_and_mask(
        static_cast<const void*>(scaled_q.data<T>()),
        static_cast<const void*>(k.data()),
        static_cast<const void*>(v.data()),
        static_cast<void*>(dq->data()),
        static_cast<void*>(dk->data()),
        static_cast<void*>(dv->data()),
        out_ptr,
        dout.data(),
        static_cast<const int32_t*>(cu_seqlens_q.data()),
        static_cast<const int32_t*>(cu_seqlens_k.data()),
        total_q,
        total_k,
        batch_size,
        num_heads,
        head_size,
        max_seqlen_q,
        max_seqlen_k,
        dropout,
        fa_with_mask_scale,
        fa_zero_tensors,
        fa_is_bf16,
        fa_num_splits,
        static_cast<const void*>(softmax_lse.data()),
        static_cast<void*>(dsoftmax.data()),
        nullptr,
        workspace_ptr,
        &workspace_size,
        stream,
        seed,
        offset,
        mask_ptr,
        nullptr,
        mask_dims_ptr,
        nullptr);
  };

  // Pass 1: out == nullptr -> the library only reports the workspace size.
  bool succ = run_bwd(nullptr, nullptr);
  CheckFlashAttnStatus(succ);

  DenseTensor workspace;
  if (workspace_size > 0) {
    // Round up so a byte size that is not a multiple of sizeof(float) is
    // never under-allocated.
    const int64_t workspace_numel = static_cast<int64_t>(
        (workspace_size + sizeof(float) - 1) / sizeof(float));
    workspace = Empty<float>(ctx, {workspace_numel});
  }

  // Pass 2: real output and workspace -> compute dq/dk/dv.
  succ = run_bwd(out.data(), workspace_size > 0 ? workspace.data() : nullptr);
  CheckFlashAttnStatus(succ);

  // dq currently holds d(loss)/d(scaled_q); rescale it in place by `scale`.
  ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
#else
  RaiseNotSupportedError();
#endif
}

// Backward pass of variable-length flash attention.
// q, k, v are packed as [total_*, num_heads(_k), head_dim]. Per-sequence
// offsets are given by cu_seqlens_q / cu_seqlens_k (int32), so the batch
// size is cu_seqlens_q.numel() - 1.
//
// Allocates dq/dk/dv, then:
//   * with an attention mask, delegates to FlashAttnUnpaddedGradImpl;
//   * otherwise calls phi::dynload::flash_attn_varlen_bwd with parameters
//     prepared by FlashAttnBwdParamsV2.
template <typename T, typename Context>
void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                 const DenseTensor& q,
                                 const DenseTensor& k,
                                 const DenseTensor& v,
                                 const DenseTensor& cu_seqlens_q,
                                 const DenseTensor& cu_seqlens_k,
                                 const DenseTensor& out,
                                 const DenseTensor& softmax_lse,
                                 const DenseTensor& seed_offset,
                                 const paddle::optional<DenseTensor>& attn_mask,
                                 const DenseTensor& dout,
                                 int64_t max_seqlen_q,
                                 int64_t max_seqlen_k,
                                 float scale,
                                 float dropout,
                                 bool causal,
                                 DenseTensor* dq,
                                 DenseTensor* dk,
                                 DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  ctx.template Alloc<T>(dq);
  ctx.template Alloc<T>(dk);
  ctx.template Alloc<T>(dv);

  if (attn_mask.get_ptr()) {
    FlashAttnUnpaddedGradImpl<T, Context>(ctx,
                                          q,
                                          k,
                                          v,
                                          cu_seqlens_q,
                                          cu_seqlens_k,
                                          out,
                                          softmax_lse,
                                          seed_offset,
                                          attn_mask,
                                          dout,
                                          max_seqlen_q,
                                          max_seqlen_k,
                                          scale,
                                          dropout,
                                          causal,
                                          dq,
                                          dk,
                                          dv);
  } else {
    // q,k,v [total_*, num_heads, head_dim]
    auto dims = q.dims();
    const int64_t batch_size = cu_seqlens_q.numel() - 1;
    const int64_t num_heads = dims[1];
    const int64_t head_size_og = dout.dims()[2];
    const int64_t head_size = dims[2];
    const int64_t num_heads_k = k.dims()[1];

    // TODO(umiswing): add deterministic in fa2.
    // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
    // if (FLAGS_cudnn_deterministic) {
    //   num_splits = 1;
    // }

    // TODO(umiswing): add shape check
    PADDLE_ENFORCE_EQ(
        head_size_og,
        head_size,
        phi::errors::InvalidArgument(
            "flash_attn_bwd receive input with head_size_og == head_size"));

    FlashAttnBwdParamsV2 params =
        FlashAttnBwdParamsV2(ctx,
                             batch_size,
                             max_seqlen_q,
                             max_seqlen_k,
                             num_heads,
                             num_heads_k,
                             head_size,
                             dropout,
                             scale,
                             causal,
                             q.dtype(),
                             seed_offset.data<int64_t>());

    VLOG(10) << "FlashAttn bwd seed: " << params.seed
             << ", offset: " << params.offset;

    const cudaStream_t stream = ctx.stream();
    bool succ =
        phi::dynload::flash_attn_varlen_bwd(dout.data(),
                                            q.data(),
                                            k.data(),
                                            v.data(),
                                            out.data(),
                                            params.softmax_d.data(),
                                            softmax_lse.data(),
                                            cu_seqlens_q.data<int32_t>(),
                                            cu_seqlens_k.data<int32_t>(),
                                            params.rng_state.data(),
                                            dq->data(),
                                            dk->data(),
                                            dv->data(),
                                            params.dq_accum.data(),
                                            params.batch_size,
                                            params.max_seqlen_q,
                                            params.max_seqlen_k,
                                            params.seqlen_q_rounded,
                                            params.seqlen_k_rounded,
                                            params.num_heads,
                                            params.num_heads_k,
                                            params.head_size,
                                            params.head_size_rounded,
                                            params.dropout,
                                            params.scale,
                                            params.causal,
                                            params.is_bf16,
                                            stream,
                                            params.seed,
                                            params.offset);
    CheckFlashAttnStatus(succ);
  }
#else
  RaiseNotSupportedError();
#endif
}

// Backward pass of fixed-length flash attention.
// q, k, v are [batch_size, seq_len, num_heads(_k), head_dim], and the softmax
// scale is 1 / sqrt(head_dim).
//
// With an attention mask, the dense inputs are reinterpreted (no copy) as
// packed [batch * seq_len, num_heads, head_dim] tensors. Uniform cu_seqlens
// are generated with arange, and the work is forwarded to
// FlashAttnUnpaddedGradKernel. Without a mask,
// phi::dynload::flash_attn_bwd is called directly.
template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
                         const DenseTensor& q,
                         const DenseTensor& k,
                         const DenseTensor& v,
                         const DenseTensor& out,
                         const DenseTensor& softmax_lse,
                         const DenseTensor& seed_offset,
                         const paddle::optional<DenseTensor>& attn_mask,
                         const DenseTensor& dout,
                         float dropout,
                         bool causal,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  // q,k,v [batch_size, seq_len, num_heads, head_dim]
  const auto& dims = q.dims();
  const int64_t batch_size = dims[0];
  const int64_t seqlen_q = dims[1];
  const int64_t num_heads = dims[2];
  const int64_t head_size_og = dout.dims()[3];
  const int64_t head_size = dims[3];
  const int64_t seqlen_k = k.dims()[1];
  const int64_t num_heads_k = k.dims()[2];

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
      head_size_og,
      head_size,
      phi::errors::InvalidArgument(
          "flash_attn_bwd receive input with head_size_og == head_size"));

  VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
           << "], v[" << v.dims() << "]";

  const float scale = 1.0f / std::sqrt(head_size);

  if (attn_mask.get_ptr()) {
    // The masked path views k/v with q's head count, and the masked library
    // call takes a single num_heads. A different KV head count would yield a
    // mis-sized view, so reject it explicitly.
    PADDLE_ENFORCE_EQ(
        num_heads_k,
        num_heads,
        phi::errors::InvalidArgument(
            "flash_attn_bwd with attn_mask requires k/v to have the same "
            "number of heads as q, but received num_heads_k = %d and "
            "num_heads = %d.",
            num_heads_k,
            num_heads));

    const int64_t total_q = batch_size * seqlen_q;
    const int64_t total_k = batch_size * seqlen_k;

    // Share storage and reinterpret [batch, seq, heads, dim] as
    // [batch * seq, heads, dim].
    DenseTensor q_t_s, k_t_s, v_t_s;
    q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size});
    k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size});
    v_t_s.ShareDataWith(v).Resize({total_k, num_heads, head_size});

    // Every sequence has the same length, so the cumulative offsets are
    // arange(0, (batch_size + 1) * seqlen, seqlen).
    DenseTensor cu_seqlens_q;
    DenseTensor cu_seqlens_k;
    ArangeNullaryKernel<int32_t, Context>(
        ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q);
    ArangeNullaryKernel<int32_t, Context>(
        ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k);

    // dq/dk/dv are allocated by FlashAttnUnpaddedGradKernel.
    FlashAttnUnpaddedGradKernel<T, Context>(ctx,
                                            q_t_s,
                                            k_t_s,
                                            v_t_s,
                                            cu_seqlens_q,
                                            cu_seqlens_k,
                                            out,
                                            softmax_lse,
                                            seed_offset,
                                            attn_mask,
                                            dout,
                                            seqlen_q,
                                            seqlen_k,
                                            scale,
                                            dropout,
                                            causal,
                                            dq,
                                            dk,
                                            dv);
  } else {
    FlashAttnBwdParamsV2 params =
        FlashAttnBwdParamsV2(ctx,
                             batch_size,
                             seqlen_q,
                             seqlen_k,
                             num_heads,
                             num_heads_k,
                             head_size,
                             dropout,
                             scale,
                             causal,
                             q.dtype(),
                             seed_offset.data<int64_t>());

    ctx.template Alloc<T>(dq);
    ctx.template Alloc<T>(dk);
    ctx.template Alloc<T>(dv);

    cudaStream_t stream = ctx.stream();

    VLOG(10) << "FlashAttn bwd seed: " << params.seed
             << ", offset: " << params.offset;

    bool succ = phi::dynload::flash_attn_bwd(dout.data(),
                                             q.data(),
                                             k.data(),
                                             v.data(),
                                             out.data(),
                                             params.softmax_d.data(),
                                             softmax_lse.data(),
                                             params.rng_state.data(),
                                             dq->data(),
                                             dk->data(),
                                             dv->data(),
                                             params.dq_accum.data(),
                                             params.batch_size,
                                             params.max_seqlen_q,
                                             params.max_seqlen_k,
                                             params.seqlen_q_rounded,
                                             params.seqlen_k_rounded,
                                             params.num_heads,
                                             params.num_heads_k,
                                             params.head_size,
                                             params.head_size_rounded,
                                             params.dropout,
                                             params.scale,
                                             params.causal,
                                             params.is_bf16,
                                             stream,
                                             params.seed,
                                             params.offset);
    CheckFlashAttnStatus(succ);
  }
#else
  RaiseNotSupportedError();
#endif
}

}  // namespace phi

// Registers the varlen flash-attention backward kernel for fp16/bf16 on GPU.
// Input 7 is seed_offset (q, k, v, cu_seqlens_q, cu_seqlens_k, out,
// softmax_lse, seed_offset, ...). It is registered with ALL_BACKEND,
// presumably so the framework does not force it onto the GPU; the masked
// path dereferences its data on the host.
PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnUnpaddedGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}

// Registers the fixed-length flash-attention backward kernel for fp16/bf16 on
// GPU. Input 5 is seed_offset (q, k, v, out, softmax_lse, seed_offset, ...).
// It is registered with ALL_BACKEND, presumably so the framework does not
// force it onto the GPU.
PD_REGISTER_KERNEL(flash_attn_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}