flash_attn_grad_kernel.cu 10.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/flash_attn_grad_kernel.h"
16
#include "glog/logging.h"  // For VLOG()
17 18
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
19
#include "paddle/phi/core/flags.h"
20 21 22 23 24 25 26 27
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
28
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
29 30
#endif

31 32
DECLARE_bool(cudnn_deterministic);

33 34 35
namespace phi {

template <typename T, typename Context>
// Backward pass of variable-length ("unpadded") FlashAttention-2.
//
// Allocates dq/dk/dv on `ctx` and fills them by calling
// phi::dynload::flash_attn_varlen_bwd with the forward inputs (q, k, v, out,
// softmax_lse), the incoming gradient `dout`, the cumulative sequence-length
// tensors and the dropout RNG state recovered from `seed_offset`.
//
// Params:
//   q, k, v        - indexed as [total_*, num_heads, head_dim] (see dims[]
//                    accesses below).
//   cu_seqlens_q/k - read as int32 data; batch_size is taken as
//                    numel(cu_seqlens_q) - 1.
//   seed_offset    - read via data<int64_t>() to rebuild dropout RNG state.
//   max_seqlen_q/k, scale, dropout, causal - forwarded to the FA2 backward.
// Throws:
//   InvalidArgument if dout's head dim differs from q's head dim.
//   External if the FlashAttention backward call reports failure.
//   Unimplemented when built without PADDLE_WITH_FLASHATTN.
void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                 const DenseTensor& q,
                                 const DenseTensor& k,
                                 const DenseTensor& v,
                                 const DenseTensor& cu_seqlens_q,
                                 const DenseTensor& cu_seqlens_k,
                                 const DenseTensor& out,
                                 const DenseTensor& softmax_lse,
                                 const DenseTensor& seed_offset,
                                 const DenseTensor& dout,
                                 int64_t max_seqlen_q,
                                 int64_t max_seqlen_k,
                                 float scale,
                                 float dropout,
                                 bool causal,
                                 DenseTensor* dq,
                                 DenseTensor* dk,
                                 DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  ctx.template Alloc<T>(dq);
  ctx.template Alloc<T>(dk);
  ctx.template Alloc<T>(dv);

  const cudaStream_t stream = ctx.stream();

  // q,k,v [total_*, num_heads, head_dim]
  const auto& dims = q.dims();
  // One more offset than there are sequences in cu_seqlens_q.
  const int batch_size = static_cast<int>(cu_seqlens_q.numel()) - 1;
  const int num_heads = dims[1];
  const int head_size_og = dout.dims()[2];
  const int head_size = dims[2];
  const int num_heads_k = k.dims()[1];

  // TODO(umiswing): add deterministic in fa2.
  // int num_splits = 0;  // 0 for an internal heuristic, which is optimal
  // if (FLAGS_cudnn_deterministic) {
  //   num_splits = 1;
  // }

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
      head_size_og,
      head_size,
      phi::errors::InvalidArgument(
          "flash_attn_bwd requires head_size_og == head_size, but received "
          "head_size_og = %d and head_size = %d.",
          head_size_og,
          head_size));

  FlashAttnBwdParamsV2 params =
      FlashAttnBwdParamsV2(ctx,
                           batch_size,
                           max_seqlen_q,
                           max_seqlen_k,
                           num_heads,
                           num_heads_k,
                           head_size,
                           dropout,
                           scale,
                           causal,
                           q.dtype(),
                           seed_offset.data<int64_t>());

  VLOG(4) << "FlashAttn bwd seed: " << params.seed
          << ", offset: " << params.offset;

  const bool succ =
      phi::dynload::flash_attn_varlen_bwd(dout.data(),
                                          q.data(),
                                          k.data(),
                                          v.data(),
                                          out.data(),
                                          params.softmax_d.data(),
                                          softmax_lse.data(),
                                          cu_seqlens_q.data<int32_t>(),
                                          cu_seqlens_k.data<int32_t>(),
                                          params.rng_state.data(),
                                          dq->data(),
                                          dk->data(),
                                          dv->data(),
                                          params.dq_accum.data(),
                                          params.batch_size,
                                          params.max_seqlen_q,
                                          params.max_seqlen_k,
                                          params.seqlen_q_rounded,
                                          params.seqlen_k_rounded,
                                          params.num_heads,
                                          params.num_heads_k,
                                          params.head_size,
                                          params.head_size_rounded,
                                          params.dropout,
                                          params.scale,
                                          params.causal,
                                          params.is_bf16,
                                          stream,
                                          params.seed,
                                          params.offset);

  if (!succ) {
    // Pass the library's message as an argument, never as the format string,
    // so a '%' inside it cannot be misinterpreted by the formatter.
    PADDLE_THROW(phi::errors::External(
        "Error in Flash-Attention-2, detail information is: %s",
        phi::dynload::flash_attn_error()));
  }
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}

C
Chitsing KUI 已提交
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
template <typename T, typename Context>
// Backward pass of padded (fixed-length) FlashAttention-2.
//
// Allocates dq/dk/dv on `ctx` and fills them by calling
// phi::dynload::flash_attn_bwd with the forward inputs (q, k, v, out,
// softmax_lse), the incoming gradient `dout` and the dropout RNG state
// recovered from `seed_offset`. The softmax scale is fixed to
// 1 / sqrt(head_size).
//
// Params:
//   q, k, v     - indexed as [batch_size, seq_len, num_heads, head_dim]
//                 (see dims[] accesses below).
//   seed_offset - read via data<int64_t>() to rebuild dropout RNG state.
//   dropout, causal - forwarded to the FA2 backward.
// Throws:
//   InvalidArgument if dout's head dim differs from q's head dim.
//   External if the FlashAttention backward call reports failure.
//   Unimplemented when built without PADDLE_WITH_FLASHATTN.
void FlashAttnGradKernel(const Context& ctx,
                         const DenseTensor& q,
                         const DenseTensor& k,
                         const DenseTensor& v,
                         const DenseTensor& out,
                         const DenseTensor& softmax_lse,
                         const DenseTensor& seed_offset,
                         const DenseTensor& dout,
                         float dropout,
                         bool causal,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
  // q,k,v [batch_size, seq_len, num_heads, head_dim]
  const auto& dims = q.dims();

  const int batch_size = dims[0];
  const int seqlen_q = dims[1];
  const int num_heads = dims[2];
  const int head_size_og = dout.dims()[3];
  const int head_size = dims[3];
  const int seqlen_k = k.dims()[1];
  const int num_heads_k = k.dims()[2];

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
      head_size_og,
      head_size,
      phi::errors::InvalidArgument(
          "flash_attn_bwd requires head_size_og == head_size, but received "
          "head_size_og = %d and head_size = %d.",
          head_size_og,
          head_size));

  VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
          << "], v[" << v.dims() << "]";

  // Compute in float: std::sqrt(int) would select the double overload and
  // then silently narrow on assignment.
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));

  FlashAttnBwdParamsV2 params =
      FlashAttnBwdParamsV2(ctx,
                           batch_size,
                           seqlen_q,
                           seqlen_k,
                           num_heads,
                           num_heads_k,
                           head_size,
                           dropout,
                           scale,
                           causal,
                           q.dtype(),
                           seed_offset.data<int64_t>());

  ctx.template Alloc<T>(dq);
  ctx.template Alloc<T>(dk);
  ctx.template Alloc<T>(dv);

  cudaStream_t stream = ctx.stream();

  VLOG(4) << "FlashAttn bwd seed: " << params.seed
          << ", offset: " << params.offset;

  const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
                                                 q.data(),
                                                 k.data(),
                                                 v.data(),
                                                 out.data(),
                                                 params.softmax_d.data(),
                                                 softmax_lse.data(),
                                                 params.rng_state.data(),
                                                 dq->data(),
                                                 dk->data(),
                                                 dv->data(),
                                                 params.dq_accum.data(),
                                                 params.batch_size,
                                                 params.max_seqlen_q,
                                                 params.max_seqlen_k,
                                                 params.seqlen_q_rounded,
                                                 params.seqlen_k_rounded,
                                                 params.num_heads,
                                                 params.num_heads_k,
                                                 params.head_size,
                                                 params.head_size_rounded,
                                                 params.dropout,
                                                 params.scale,
                                                 params.causal,
                                                 params.is_bf16,
                                                 stream,
                                                 params.seed,
                                                 params.offset);

  // The format string needs a %s conversion; without it the library's error
  // detail passed as the extra argument is never rendered.
  PADDLE_ENFORCE_EQ(
      succ,
      true,
      phi::errors::External(
          "Error in Flash-Attention-2, detail information is: %s",
          phi::dynload::flash_attn_error()));
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "FlashAttention is unsupported, please set use_flash_attn to false."));
#endif
}

246 247
}  // namespace phi

248
PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnUnpaddedGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  // Input 7 is seed_offset (inputs: q, k, v, cu_seqlens_q, cu_seqlens_k, out,
  // softmax_lse, seed_offset, dout). ALL_BACKEND presumably exempts it from
  // automatic transfer to GPU so the kernel can read it in place via
  // data<int64_t>() — confirm against the phi backend-transform rules.
  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}

257 258 259 260 261 262
PD_REGISTER_KERNEL(flash_attn_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::FlashAttnGradKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  // Input 5 is seed_offset (inputs: q, k, v, out, softmax_lse, seed_offset,
  // dout). ALL_BACKEND presumably exempts it from automatic transfer to GPU
  // so the kernel can read it in place via data<int64_t>() — confirm against
  // the phi backend-transform rules.
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
}