// fused_softmax_mask_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_utils.h"

namespace phi {
namespace fusion {

// Forward kernel computing y = softmax(x + mask) along rows of length
// key_seq_len.
//
// Template parameters:
//   T          - element type of x_data / y_data (this file registers the
//                operator for float and float16).
//   MT         - element type of mask_data.
//   pow2_index - the kernel is specialised for a padded row length of
//                (1 << pow2_index).
//
// Parameters:
//   x_data      - input rows, key_seq_len elements per row.
//   mask_data   - mask rows added to x before the softmax.
//   y_data      - output, indexed exactly like x_data.
//   batch_count - total number of rows in x_data / y_data.
//   key_seq_len - number of valid elements per row.
//
// Indexing: threadIdx.x selects the lane inside a row, threadIdx.y selects
// the row group inside the block, and each (block, threadIdx.y) pair handles
// kLocalBatchSize consecutive rows. The mask row index is built without
// blockIdx.y, so every blockIdx.y slice reads the same mask rows.
//
// NOTE(review): every access moves kOneLoadingCounts (4) elements through
// load_data and only the first element's index is bounds-checked, so this
// presumably requires key_seq_len to be a multiple of 4 - confirm against
// load_data in fused_softmax_mask_utils.h.
template <typename T, typename MT, int pow2_index>
__global__ void SoftmaxMaskFuseGPUKernel(const T* x_data,
                                         const MT* mask_data,
                                         T* y_data,
                                         int batch_count,
                                         int key_seq_len) {
  constexpr int next_pow2 = 1 << pow2_index;
  constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
  // Per-thread element count for one row; never fewer than one 4-wide load.
  constexpr int kLocalIterations =
      (next_pow2 / warp_size > 4) ? (next_pow2 / warp_size) : 4;
  // Short rows (<= 128 padded elements) are processed two at a time.
  constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
  constexpr int kOneLoadingCounts = 4;

  // Index of the first x/y row handled by this thread's row group.
  int data_first_idx =
      (blockDim.y *
           (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
       threadIdx.y) *
      kLocalBatchSize;

  // Index of the first mask row; blockIdx.y does not take part.
  int mask_first_idx =
      (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
      kLocalBatchSize;

  // batch_count might not be a multiple of kLocalBatchSize. Check how many
  // rows actually have to be computed by this row group.
  int local_batches = batch_count - data_first_idx;
  if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;

  // Lane index inside the row.
  int local_idx = threadIdx.x;

  // Element offsets are computed in 64 bits: row_index * key_seq_len exceeds
  // INT_MAX as soon as the tensor holds more than 2^31 elements.
  const int64_t lane_offset =
      static_cast<int64_t>(kOneLoadingCounts) * local_idx;
  const int64_t x_offset =
      static_cast<int64_t>(data_first_idx) * key_seq_len + lane_offset;
  const int64_t mask_offset =
      static_cast<int64_t>(mask_first_idx) * key_seq_len + lane_offset;
  x_data += x_offset;
  mask_data += mask_offset;
  y_data += x_offset;

  // All intermediate math is done in float.
  float data[kLocalBatchSize][kLocalIterations];
  T temp_data[kOneLoadingCounts];
  MT temp_mask[kOneLoadingCounts];

#pragma unroll
  for (int i = 0; i < kLocalBatchSize; ++i) {
    // Rows past local_batches get a length of 0 so they are filled with -inf.
    int batch_data = (i >= local_batches) ? 0 : key_seq_len;

#pragma unroll
    for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
      int data_index = kOneLoadingCounts * local_idx + ii * warp_size;

      if (data_index < batch_data) {
        int itr_idx = i * key_seq_len + ii * warp_size;

        // Fetch kOneLoadingCounts elements of x and of the mask at once.
        load_data(temp_data, x_data + itr_idx);
        load_data(temp_mask, mask_data + itr_idx);

#pragma unroll
        for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
          data[i][ii + counter] = static_cast<float>(temp_data[counter]) +
                                  static_cast<float>(temp_mask[counter]);
        }
      } else {
        // Padding: -inf never wins the max and contributes exp(-inf) = 0.
#pragma unroll
        for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
          data[i][ii + counter] = -std::numeric_limits<float>::infinity();
        }
      }
    }
  }

  // Per-thread maximum of each row's local elements.
  float samples_max_value[kLocalBatchSize];
#pragma unroll
  for (int i = 0; i < kLocalBatchSize; ++i) {
    samples_max_value[i] = data[i][0];
#pragma unroll
    for (int ii = 1; ii < kLocalIterations; ++ii) {
      samples_max_value[i] = (samples_max_value[i] > data[i][ii])
                                 ? samples_max_value[i]
                                 : data[i][ii];
    }
  }
  // Combine the per-thread maxima across the lanes of the row.
  warp_reduce<float, kLocalBatchSize, warp_size, MaxOP>(samples_max_value);

  // Per-thread sum of exp(value - row_max); the exponentials replace data.
  float samples_sum[kLocalBatchSize]{0.0f};
#pragma unroll
  for (int i = 0; i < kLocalBatchSize; ++i) {
#pragma unroll
    for (int ii = 0; ii < kLocalIterations; ++ii) {
      data[i][ii] = std::exp((data[i][ii] - samples_max_value[i]));
      samples_sum[i] += data[i][ii];
    }
  }
  // Combine the per-thread sums across the lanes of the row.
  warp_reduce<float, kLocalBatchSize, warp_size, AddOP>(samples_sum);

  // Normalise and store the results into y (load_data is used here as a
  // kOneLoadingCounts-element copy from samples_out to y_data).
  T samples_out[kOneLoadingCounts];
#pragma unroll
  for (int i = 0; i < kLocalBatchSize; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
      int idx = kOneLoadingCounts * local_idx + ii * warp_size;
      if (idx < key_seq_len) {
#pragma unroll
        for (int counter = 0; counter < kOneLoadingCounts; ++counter) {
          samples_out[counter] = data[i][ii + counter] / samples_sum[i];
        }
        load_data(y_data + i * key_seq_len + ii * warp_size, samples_out);
      } else {
        break;
      }
    }
  }
}

// Launches the SoftmaxMaskFuseGPUKernel specialisation matching the runtime
// pow2_index (5..13, i.e. padded row lengths 32..8192). Callers must have
// validated pow2_index; any other value launches nothing.
template <typename T, typename MT, typename StreamT>
static void LaunchSoftmaxMaskFuseGPUKernel(int pow2_index,
                                           const dim3& blocks,
                                           const dim3& threads,
                                           StreamT stream,
                                           const T* x_data,
                                           const MT* mask_data,
                                           T* y_data,
                                           int batch_count,
                                           int key_seq_len) {
  switch (pow2_index) {
    case 5:  // 32
      SoftmaxMaskFuseGPUKernel<T, MT, 5><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 6:  // 64
      SoftmaxMaskFuseGPUKernel<T, MT, 6><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 7:  // 128
      SoftmaxMaskFuseGPUKernel<T, MT, 7><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 8:  // 256
      SoftmaxMaskFuseGPUKernel<T, MT, 8><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 9:  // 512
      SoftmaxMaskFuseGPUKernel<T, MT, 9><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 10:  // 1024
      SoftmaxMaskFuseGPUKernel<T, MT, 10><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 11:  // 2048
      SoftmaxMaskFuseGPUKernel<T, MT, 11><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 12:  // 4096
      SoftmaxMaskFuseGPUKernel<T, MT, 12><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    case 13:  // 8192
      SoftmaxMaskFuseGPUKernel<T, MT, 13><<<blocks, threads, 0, stream>>>(
          x_data, mask_data, y_data, batch_count, key_seq_len);
      break;
    default:
      break;
  }
}

// Host entry point of the fused softmax-mask forward operator:
// out = softmax(x + mask) along the last dimension of x.
//
// Parameters:
//   dev_ctx - device context providing the output allocation and the stream.
//   x       - input, read as dims [batches, attn_heads, query_seq_len,
//             key_seq_len].
//   mask    - mask with the same dims as x except dim 1, which must be 1.
//             Its dtype must equal x's dtype or be FLOAT32.
//   out     - output tensor, allocated here with element type T.
//
// Raises InvalidArgument when:
//   - query_seq_len <= 1, or key_seq_len is outside [32, 8192);
//   - mask dims do not match the rule above;
//   - query_seq_len is not divisible by the rows handled per block;
//   - the mask dtype is unsupported (previously this silently left `out`
//     unwritten).
//
// NOTE(review): x and mask are indexed as rank-4 tensors without a rank
// check - presumably guaranteed by the operator's shape inference; confirm.
template <typename T, typename Context>
void FusedSoftmaxMaskKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const DenseTensor& mask,
                            DenseTensor* out) {
  auto* x_data = x.data<T>();
  auto* y_data = dev_ctx.template Alloc<T>(out);

  auto x_dim = x.dims();
  auto mask_dim = mask.dims();
  auto batches = x_dim[0];
  auto attn_heads = x_dim[1];
  auto query_seq_len = x_dim[2];
  auto key_seq_len = x_dim[3];

  PADDLE_ENFORCE_GT(query_seq_len,
                    1,
                    phi::errors::InvalidArgument(
                        "Input x's second last dim must be large than 1 but "
                        "received the second last dimension of x is %d",
                        query_seq_len));

  PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
                    true,
                    phi::errors::InvalidArgument(
                        "Input x's last dim must be between [32, 8192) "
                        "received the last dimension of x is %d",
                        key_seq_len));

  PADDLE_ENFORCE_EQ(mask_dim[1],
                    1,
                    phi::errors::InvalidArgument(
                        "Input mask's second dim must be 1 "
                        "received the second dimension of mask is %d",
                        mask_dim[1]));

  // dim of x and mask must be equal (dim 1 is the broadcast dim, skipped)
  for (size_t idx = 0; idx < 4; ++idx) {
    if (idx == 1) continue;
    PADDLE_ENFORCE_EQ(
        x_dim[idx],
        mask_dim[idx],
        phi::errors::InvalidArgument(
            "Input x's %dth dim should be equal with input mask's %dth dim "
            "but "
            "received the %dth dimension of x and mask are not equal "
            "the %dth dim of x is %d, while the %dth dim of mask is %d.",
            idx,
            idx,
            idx,
            idx,
            x_dim[idx],
            idx,
            mask_dim[idx]));
  }

  // Only two mask element types have a kernel instantiation below; reject
  // anything else instead of returning with `out` unwritten.
  const bool mask_matches_x = (mask.dtype() == x.dtype());
  const bool mask_is_fp32 = (mask.dtype() == phi::DataType::FLOAT32);
  PADDLE_ENFORCE_EQ(
      mask_matches_x || mask_is_fp32,
      true,
      phi::errors::InvalidArgument(
          "Input mask's dtype must be the same as input x's dtype or "
          "float32."));

  auto stream = dev_ctx.stream();

  int pow2_index = get_pow2(key_seq_len);
  // The dispatch only covers 2^5 .. 2^13; fail loudly rather than skip the
  // launch if get_pow2 ever returns something else.
  PADDLE_ENFORCE_EQ(
      pow2_index >= 5 && pow2_index <= 13,
      true,
      phi::errors::InvalidArgument(
          "The power-of-two index computed from the last dim of x must be "
          "between [5, 13], but received %d.",
          pow2_index));

  const int next_pow2 = 1 << pow2_index;
  int batch_count = static_cast<int>(batches * attn_heads * query_seq_len);
  int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
  int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
  // use 128 threads per block to maximum gpu utilization
  constexpr int threads_per_block = 128;

  int warps_per_block = (threads_per_block / warp_size);
  int batches_per_block = warps_per_block * batches_per_warp;
  PADDLE_ENFORCE_EQ(
      query_seq_len % batches_per_block,
      0,
      phi::errors::InvalidArgument(
          "The query seq len (third dim of input X) must can divide the "
          "number of batches per block. The query seq len is %d, while "
          "the number of batches per block is %d.",
          query_seq_len,
          batches_per_block));
  // Grid: x covers query rows in groups of batches_per_block, y the heads,
  // z the batches. Block: x is the lane within a row, y the row group.
  dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
  dim3 threads(warp_size, warps_per_block, 1);

  // key_seq_len was range-checked above, so narrowing to int is safe.
  const int row_len = static_cast<int>(key_seq_len);

  if (mask_matches_x) {
    auto* mask_data = mask.data<T>();
    LaunchSoftmaxMaskFuseGPUKernel<T, T>(pow2_index,
                                         blocks,
                                         threads,
                                         stream,
                                         x_data,
                                         mask_data,
                                         y_data,
                                         batch_count,
                                         row_len);
  } else {
    // mask_is_fp32 is guaranteed by the dtype check above.
    auto* mask_data = mask.data<float>();
    LaunchSoftmaxMaskFuseGPUKernel<T, float>(pow2_index,
                                             blocks,
                                             threads,
                                             stream,
                                             x_data,
                                             mask_data,
                                             y_data,
                                             batch_count,
                                             row_len);
  }
}

}  // namespace fusion
}  // namespace phi

// Registers FusedSoftmaxMaskKernel as the GPU implementation of
// fused_softmax_mask for the float and float16 element types.
PD_REGISTER_KERNEL(fused_softmax_mask,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedSoftmaxMaskKernel,
                   float,
                   phi::dtype::float16) {}