// fused_attention_kernel.cu
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/fused_attention_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/matmul_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

namespace phi {
namespace sparse {

// Row-wise masked softmax over the non-zero values of a batched CSR matrix.
//
// For every CSR row: out = exp(x - x_max) / sum(exp(x - x_max)). A non-zero
// whose `kp_mask` or `attn_mask` entry equals 0 is treated as -inf and thus
// becomes 0 in the output. A row whose entries are all masked is written as
// all zeros (instead of NaN from exp(-inf - (-inf))).
//
// Launch layout: threadIdx.y selects the row inside the block
// (row = blockIdx.x * blockDim.y + threadIdx.y) and the threads along x stride
// over the non-zeros of that row. Row max and row sum are combined with
// warp-level reductions using the full 0xFFFFFFFF mask, so blockDim.x is
// expected to equal the warp size (every lane of a warp works on one row).
// No shared memory is used.
//
// Params:
//   x_crows        CSR row offsets, (M + 1) entries per batch.
//   x_cols         column index of every non-zero.
//   x_values       input value of every non-zero.
//   kp_mask        optional (may be nullptr); read at
//                  [(batch / num_heads) * M + col].
//   attn_mask      optional (may be nullptr); read at [row * M + col].
//   out_values     output values, same layout as x_values.
//   M              number of rows per batch (also the mask row stride).
//   total_row_num  total number of rows over all batches.
//   num_heads      divisor mapping a batch index to its kp_mask row.
//   batch_nnz      stride, in non-zeros, between consecutive batches in
//                  x_cols / x_values / out_values.
template <typename T>
__global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                     const int64_t* x_cols,
                                     const T* x_values,
                                     const T* kp_mask,
                                     const T* attn_mask,
                                     T* out_values,
                                     int M,
                                     int total_row_num,
                                     int num_heads,
                                     int batch_nnz) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  // Both early exits below are uniform across the lanes that share a row, so
  // the full-mask warp reductions further down stay well defined.
  if (row >= total_row_num) return;

  int cur_batch = row / M;
  int cur_row = row % M;
  int crow_idx = cur_batch * (M + 1) + cur_row;
  // 64-bit offset: cur_batch * batch_nnz can exceed the int range.
  const int64_t row_first =
      static_cast<int64_t>(cur_batch) * batch_nnz + x_crows[crow_idx];
  int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
  if (row_nnz == 0) return;

  const T neg_inf = -std::numeric_limits<T>::infinity();

  // Pass 1: apply the masks, stage the (masked) inputs in out_values and
  // track the per-lane maximum.
  T max_val = neg_inf;
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    bool mask = false;
    const int64_t col_idx = x_cols[row_first + idx];
    // Mask indices are computed in 64 bits: row * M overflows int for long
    // sequences.
    if (kp_mask != nullptr &&
        kp_mask[static_cast<int64_t>(cur_batch / num_heads) * M + col_idx] ==
            0) {
      mask = true;
    }
    if (attn_mask != nullptr &&
        attn_mask[static_cast<int64_t>(cur_row) * M + col_idx] == 0) {
      mask = true;
    }

    if (!mask) {
      T val = x_values[row_first + idx];
      if (val > max_val) {
        max_val = val;
      }
      out_values[row_first + idx] = val;
    } else {
      out_values[row_first + idx] = neg_inf;
    }
  }
  T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);

  // A row maximum of -inf means every entry was masked; exp(-inf - (-inf))
  // would be NaN, so such rows are emitted as zeros instead.
  const bool row_fully_masked = (row_max_val == neg_inf);

  // Pass 2: exponentiate in place and accumulate the per-lane sum.
  auto exp_functor = phi::funcs::CudaExpFunctor<T>();
  T exp_sum = 0;
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    T exp_val = row_fully_masked
                    ? static_cast<T>(0)
                    : exp_functor(out_values[row_first + idx] - row_max_val);
    exp_sum += exp_val;
    out_values[row_first + idx] = exp_val;
  }
  // Executed by every lane (even for fully masked rows) so that all lanes
  // named in the full mask take part in the reduction.
  T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);

  // Outputs are already zero; skip the 0 / 0 normalisation.
  if (row_fully_masked) return;

  // Pass 3: normalise.
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
  }
}

// Forward pass of sparse fused attention on GPU.
//
// Steps performed (CUDA >= 11.7 only; otherwise throws Unimplemented):
//   1. SDDMM: scores = (query * key^T) / sqrt(N), evaluated only at the
//      positions present in `sparse_mask` (N is the last dim of `query`).
//   2. Masked row-wise softmax of the scores via AttnSoftmaxGpuKernel,
//      written into `softmax`.
//   3. out = softmax * value via MatmulCsrDenseKernel.
//
// Params:
//   dev_ctx           device context; supplies the stream and sparse BLAS.
//   query/key/value   dense tensors, each enforced to be 4-D.
//   sparse_mask       CSR tensor whose dense shape is enforced to be
//                     [query.dims[0] * query.dims[1], M, M] with
//                     M = query.dims[2].
//   key_padding_mask  optional; enforced shape [query.dims[0], M].
//   attn_mask         optional; enforced shape [M, M].
//   out               dense output of step 3.
//   softmax           CSR output of step 2; its dims are set to
//                     [q0, q1, q2, q2] before step 3.
//
// Raises InvalidArgument when any of the shape requirements above fails.
template <typename T, typename Context>
void FusedAttentionCsrKernel(
    const Context& dev_ctx,
    const DenseTensor& query,
    const DenseTensor& key,
    const DenseTensor& value,
    const SparseCsrTensor& sparse_mask,
    const paddle::optional<DenseTensor>& key_padding_mask,
    const paddle::optional<DenseTensor>& attn_mask,
    DenseTensor* out,
    SparseCsrTensor* softmax) {
#if CUDA_VERSION >= 11070
  /* Check Shape */
  // Ranks are validated first: the dimension lookups below index
  // q_dim[q_rank - 2] / q_dim[q_rank - 1] and must not run on a tensor of
  // unexpected rank.
  PADDLE_ENFORCE_EQ(query.dims().size(),
                    4,
                    phi::errors::InvalidArgument(" 'query' must be 4D Tensor"));
  PADDLE_ENFORCE_EQ(key.dims().size(),
                    4,
                    phi::errors::InvalidArgument(" 'key' must be 4D Tensor"));
  PADDLE_ENFORCE_EQ(value.dims().size(),
                    4,
                    phi::errors::InvalidArgument(" 'value' must be 4D Tensor"));

  auto q_dim = query.dims();
  auto q_rank = q_dim.size();

  // batch_num    = product of all dims except the last two.
  // total_row_num = batch_num * (second to last dim).
  int total_row_num = 1;
  int batch_num = 1;
  for (int i = 0; i < q_rank - 1; ++i) {
    total_row_num *= q_dim[i];
    if (i < q_rank - 2) {
      batch_num *= q_dim[i];
    }
  }
  int M = q_dim[q_rank - 2];
  int N = q_dim[q_rank - 1];

  PADDLE_ENFORCE_EQ(
      sparse_mask.dims().size(),
      3,
      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                   "[batch_size*num_heads, seq_len, seq_len]"));
  PADDLE_ENFORCE_EQ(
      sparse_mask.dims()[0],
      q_dim[0] * q_dim[1],
      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                   "[batch_size*num_heads, seq_len, seq_len]"));
  PADDLE_ENFORCE_EQ(
      sparse_mask.dims()[1],
      M,
      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                   "[batch_size*num_heads, seq_len, seq_len]"));
  PADDLE_ENFORCE_EQ(
      sparse_mask.dims()[2],
      M,
      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                   "[batch_size*num_heads, seq_len, seq_len]"));

  const auto kp_mask_ptr = key_padding_mask.get_ptr();
  if (kp_mask_ptr) {
    PADDLE_ENFORCE_EQ(
        kp_mask_ptr->dims().size(),
        2,
        phi::errors::InvalidArgument(
            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
    PADDLE_ENFORCE_EQ(
        kp_mask_ptr->dims()[0],
        q_dim[0],
        phi::errors::InvalidArgument(
            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
    PADDLE_ENFORCE_EQ(
        kp_mask_ptr->dims()[1],
        M,
        phi::errors::InvalidArgument(
            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
  }

  const auto attn_mask_ptr = attn_mask.get_ptr();
  if (attn_mask_ptr) {
    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims().size(),
                      2,
                      phi::errors::InvalidArgument(
                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[0],
                      M,
                      phi::errors::InvalidArgument(
                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[1],
                      M,
                      phi::errors::InvalidArgument(
                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
  }

  /* Step1: SDD Matmul, reuse matmul */
  SparseCsrTensor sdd_result;
  EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
  // alpha = 1 / sqrt(N), beta = 0; key is used transposed.
  sparse_blas.SDDMM(false,
                    true,
                    static_cast<T>(1 / std::sqrt(N)),
                    query,
                    key,
                    static_cast<T>(0),
                    &sdd_result);

  /* Step2: masked softmax over the non-zeros of every row */
  EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);

  // One row per threadIdx.y, WARP_SIZE lanes per row, 8 rows per block.
  constexpr int kRowsPerBlock = 8;
  dim3 grid((total_row_num + kRowsPerBlock - 1) / kRowsPerBlock);
  dim3 block(WARP_SIZE, kRowsPerBlock);

  // Guard the division: batch_num is 0 when a leading dimension is 0.
  const int batch_nnz =
      batch_num > 0 ? static_cast<int>(sdd_result.nnz() / batch_num) : 0;
  // A grid with zero blocks is an invalid launch configuration, so the launch
  // is skipped when there are no rows to process.
  if (total_row_num > 0) {
    AttnSoftmaxGpuKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
        sdd_result.non_zero_crows().data<int64_t>(),
        sdd_result.non_zero_cols().data<int64_t>(),
        sdd_result.values().data<T>(),
        kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
        attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
        softmax->mutable_values()->data<T>(),
        M,
        total_row_num,
        static_cast<int>(q_dim[1]),
        batch_nnz);
  }

  /* Step3: DSD Matmul, reuse matmul */
  softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
  MatmulCsrDenseKernel<T, Context>(dev_ctx, *softmax, value, out);
#else
  PADDLE_THROW(
      phi::errors::Unimplemented("forward of 'sparse.nn.functional.attention' "
                                 "use 'cusparseCsrSetStridedBatch', which is "
                                 "completed supported from CUDA 11.7"));
#endif
}

}  // namespace sparse
}  // namespace phi

// Registers phi::sparse::FusedAttentionCsrKernel under the kernel name
// `fused_attention_csr` for the GPU backend with ALL_LAYOUT, listing float
// and double as the supported element types.
PD_REGISTER_KERNEL(fused_attention_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::FusedAttentionCsrKernel,
                   float,
                   double) {
  // Mark input 0 as using the SPARSE_CSR data layout.
  // NOTE(review): input 0 presumably corresponds to `query`, which the kernel
  // signature declares as a DenseTensor -- confirm this index is intended.
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}