// multihead_matmul_op.cu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {

// Participation mask naming all 32 lanes; passed down to the *_sync warp
// shuffles in the reductions below.
#define FINAL_MASK 0xffffffff
// First XOR-shuffle distance used by the warp reductions (half of a warp).
#define HALF_WARP 16
// Lanes per warp; also the number of per-warp slots in the shared scratch
// arrays used by the block reductions.
#define WARP_SIZE 32

// Sums `val` across the lanes of one warp using an XOR butterfly
// (shuffle distances 16, 8, 4, 2, 1). When all 32 lanes participate, every
// lane returns the warp-wide sum.
//   lane_mask: participation mask forwarded to __shfl_xor_sync.
template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
  for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
#else
    // Legacy mask-less shuffle for older toolkits / architectures.
    val += __shfl_xor(val, mask, warpSize);
#endif
  }
  return val;
}

/* Calculate the sum of all elements in a (1-D) block.
 *
 * Each warp first reduces its values with warpReduceSum. Lane 0 of every
 * warp stores its partial sum in a shared slot, and warp 0 then reduces those
 * slots. Only warp 0 (threadIdx.x < 32) returns the block-wide sum; threads
 * in other warps get a meaningless value and must not use it.
 *
 * Expects blockDim.x to be a whole number of warps and at most 32 warps
 * (one shared slot per warp). `mask` is forwarded to the shuffles.
 * No barrier follows the final read of the shared slots. Two back-to-back
 * calls for the same T could therefore race unless a __syncthreads()
 * separates them.
 */
template <typename T>
__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;  // lane index within the warp
  int wid = threadIdx.x >> 5;     // warp index within the block

  val = warpReduceSum<T>(val, mask);

  if (lane == 0) shared[wid] = val;

  __syncthreads();

  // Number of warps in the block (blockDim.x rounded up to whole warps);
  // lanes beyond that contribute the additive identity.
  int num_warps = (blockDim.x + warpSize - 1) >> 5;
  val = (threadIdx.x < num_warps) ? shared[lane] : static_cast<T>(0.0f);
  val = warpReduceSum<T>(val, mask);

  return val;
}

// Maximum of `val` across the lanes of one warp using an XOR butterfly
// (shuffle distances 16, 8, 4, 2, 1). When all 32 lanes participate, every
// lane returns the warp-wide maximum.
//   lane_mask: participation mask forwarded to __shfl_xor_sync.
template <typename T>
__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
  for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
    // Legacy mask-less shuffle for older toolkits / architectures.
    val = max(val, __shfl_xor(val, mask, warpSize));
#endif
  }
  return val;
}

/* Calculate the maximum of all elements in a (1-D) block.
 *
 * Same two-stage scheme as blockReduceSum:
 *   1. a warp-level max,
 *   2. one shared slot per warp,
 *   3. a final max in warp 0.
 * Only warp 0 (threadIdx.x < 32) returns the block-wide maximum. Lanes past
 * the warp count are padded with -1e10f, so inputs below that value are not
 * handled.
 *
 * Expects blockDim.x to be a whole number of warps and at most 32 warps.
 * No barrier follows the final read of the shared slots. Back-to-back calls
 * for the same T therefore need a __syncthreads() in between.
 */
template <typename T>
__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;  // lane index within the warp
  int wid = threadIdx.x >> 5;     // warp index within the block

  val = warpReduceMax<T>(val, mask);

  if (lane == 0) shared[wid] = val;

  __syncthreads();

  // Number of warps in the block (blockDim.x rounded up to whole warps).
  int num_warps = (blockDim.x + warpSize - 1) >> 5;
  val = (threadIdx.x < num_warps) ? shared[lane] : static_cast<T>(-1e10f);
  val = warpReduceMax<T>(val, mask);

  return val;
}

// Adds a per-column bias to Q, K and V and scatters the result from
// row-major [B*S, N*H] into [B, N, S, H] order (B = batch_size,
// S = seq_len, N = head_num, H = size_per_head).
// Launch: one block per token row (blockIdx.x, taken modulo B*S), with one
// thread per column; blockDim.x == N*H is assumed.
template <typename T>
__global__ void add_QKV(const T *Q, const T *K, const T *V, T *q_buf_,
                        T *k_buf_, T *v_buf_, const T *bias_q, const T *bias_k,
                        const T *bias_v, int batch_size, int seq_len,
                        int head_num, int size_per_head) {
  const int token = blockIdx.x % (batch_size * seq_len);
  const int col = threadIdx.x;
  const int src = token * (head_num * size_per_head) + col;

#if __CUDA_ARCH__ >= 350
  const T q_val = __ldg(&Q[src]) + __ldg(&bias_q[col]);
  const T k_val = __ldg(&K[src]) + __ldg(&bias_k[col]);
  const T v_val = __ldg(&V[src]) + __ldg(&bias_v[col]);
#else
  const T q_val = Q[src] + bias_q[col];
  const T k_val = K[src] + bias_k[col];
  const T v_val = V[src] + bias_v[col];
#endif

  // Destination coordinates in [B, N, S, H].
  const int b = token / seq_len;
  const int s = blockIdx.x % seq_len;
  const int n = threadIdx.x / size_per_head;
  const int h = threadIdx.x % size_per_head;
  const int dst = ((b * head_num + n) * seq_len + s) * size_per_head + h;

  q_buf_[dst] = q_val;
  k_buf_[dst] = k_val;
  v_buf_[dst] = v_val;
}

// Alternative to add_QKV, kept to compare performance.
//
// Q, K and V are viewed as one stack of 3 * B * S token rows, each
// N * H (= head_num * size_per_head) wide.
// Block blockIdx.x handles word_per_block consecutive rows, starting at row
// blockIdx.x * word_per_block. That start row divided by B * S selects Q, K
// or V. There is one thread per column (blockDim.x == N * H is assumed).
// Each element gets its column bias and is written in [B, N, S, H] order to
// q_buf_, k_buf_ or v_buf_.
// NOTE(review): all rows of one block appear to be assumed to belong to the
// same tensor and batch entry, i.e. presumably seq_len % word_per_block == 0.
// Confirm this with callers.
template <typename T>
__global__ void add_QKV_V2(const T *Q, const T *K, const T *V, T *q_buf_,
                           T *k_buf_, T *v_buf_, const T *bias_Q,
                           const T *bias_K, const T *bias_V, int batch_size,
                           int seq_len, int head_num, int size_per_head,
                           const int word_per_block) {
  const int rows_per_tensor = batch_size * seq_len;
  const int row_width = head_num * size_per_head;
  const int first_row = blockIdx.x * word_per_block;

  const int tensor_id = first_row / rows_per_tensor;
  const int local_row = first_row % rows_per_tensor;

  const T *src;
  T *dst;
  const T *col_bias;
  switch (tensor_id) {
    case 0:
      src = Q;
      dst = q_buf_;
      col_bias = bias_Q;
      break;
    case 1:
      src = K;
      dst = k_buf_;
      col_bias = bias_K;
      break;
    default:
      src = V;
      dst = v_buf_;
      col_bias = bias_V;
      break;
  }
  src += local_row * row_width;

  const int batch_id = local_row / seq_len;
  const int head_id = threadIdx.x / size_per_head;
  const int lane_in_head = threadIdx.x % size_per_head;
  const int first_word = first_row % seq_len;

#if __CUDA_ARCH__ >= 350
  const T b = __ldg(&col_bias[threadIdx.x]);
#else
  const T b = col_bias[threadIdx.x];
#endif

  // Base of this thread's (batch, head, element) column in the output.
  T *head_out = dst + batch_id * (seq_len * head_num * size_per_head) +
                head_id * seq_len * size_per_head + lane_in_head;
  for (int w = 0; w < word_per_block; ++w) {
    head_out[(first_word + w) * size_per_head] =
        src[w * row_width + threadIdx.x] + b;
  }
}

// In-place row softmax of attention scores after adding an element-wise bias:
//   qk_buf_[row, j] = softmax_j(qk_buf_[row, j] + bias_qk_[row, j])
// Each block processes one row of `seq_len` scores starting at
// blockIdx.x * seq_len. Both buffers use the same indexing.
// Launch requirements:
//   - a 1-D block with blockDim.x a multiple of 32 and >= seq_len;
//   - threads with threadIdx.x >= seq_len only pad the reductions.
// Arithmetic is done in float. batch_size and head_num are unused.
// `mask` is forwarded to the warp shuffles.
template <typename T>
__global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                           const int batch_size,
                                           const int head_num,
                                           const int seq_len,
                                           const unsigned mask) {
  // Start of this block's row.
  int qk_offset = blockIdx.x * seq_len;
  // The shuffle-based reductions need whole warps.
  assert(blockDim.x % 32 == 0);

  __shared__ float s_sum, s_max;

  const bool in_row = threadIdx.x < seq_len;
  // Biased score. Padding threads use a large negative value so they never
  // win the max reduction.
  float tmp = in_row ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
                                          bias_qk_[threadIdx.x + qk_offset])
                     : -1e20f;

  // The reduced value is valid in warp 0, so thread 0 publishes it.
  float max_val = blockReduceMax<float>(tmp, mask);
  if (threadIdx.x == 0) s_max = max_val;
  __syncthreads();

  // Subtract the row maximum before exponentiating for numerical stability.
  float qk_tmp = in_row ? __expf(tmp - s_max) : 0.0f;
  float sum_val = blockReduceSum<float>(qk_tmp, mask);

  if (threadIdx.x == 0) {
    // Small epsilon keeps the division finite.
    s_sum = sum_val + 1e-6f;
  }
  __syncthreads();

  if (in_row) {
    qk_buf_[threadIdx.x + qk_offset] = static_cast<T>(qk_tmp / s_sum);
  }
}

// For verify result: adds bias_qk element-wise into qk_buf (qk_buf += bias).
// Launch: one block per score row (blockIdx.x, taken modulo
// batch_size * head_num * seq_len). Element index is
// row * seq_len + threadIdx.x; threads beyond seq_len do nothing.
// size_per_head is unused.
template <typename T>
__global__ void elt_qk_add(const T *bias_qk, T *qk_buf, int head_num,
                           int seq_len, int size_per_head, int batch_size) {
  if (threadIdx.x >= seq_len) return;

  const int rows = batch_size * head_num * seq_len;
  const int row_id = blockIdx.x % rows;
  const int dst_id = row_id * seq_len + threadIdx.x;
  // Keep the bias in T. The previous `int` temporary truncated
  // floating-point biases.
#if __CUDA_ARCH__ >= 350
  const T tmp_bias = __ldg(&bias_qk[dst_id]);
#else
  const T tmp_bias = bias_qk[dst_id];
#endif

  qk_buf[dst_id] += tmp_bias;
}

// Compute Q*K -> eltadd(bias_qk) -> softmax for every (batch, head) pair:
//   qk_buf_ = alpha * op(Q) x op(K) + beta * qk_buf_
// This is a batched GEMM over batch_size * head_num matrices of
// (seq_len x size_per_head); Q and K use the same per-matrix stride.
// Then softmax_kernel_with_eltadd adds bias_qk and applies a row softmax in
// place, one block per score row.
// Enforces seq_len <= 1024, because one softmax row must fit in a block.
template <typename T>
void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
                      int seq_len, int size_per_head, int batch_size,
                      bool q_trans, bool k_trans, T *q_buf_, T *k_buf_,
                      T *qk_buf_, const T *bias_qk, T alpha, T beta) {
  // Validate before enqueuing any GPU work.
  PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
                                       "seq_len should <= 1024, "
                                       "but received seq_len is:%d",
                                       seq_len));

  CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;

  auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
  auto stream = context.stream();

  blas.BatchedGEMM(transA, transB, seq_len, seq_len, size_per_head, alpha,
                   q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
                   seq_len * size_per_head, seq_len * size_per_head);

  // One block per score row. Round the block size up to the smallest power
  // of two that is >= max(seq_len, 32). That keeps it a whole number of
  // warps (required by the reductions) and, given the check above, <= 1024.
  const int grid = batch_size * head_num * seq_len;
  int block = WARP_SIZE;
  while (block < seq_len) block <<= 1;

  softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
      qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}

// Reorders per-head results from src [B, N, S, H] into dst [B, S, N, H]
// (B = batch_size, N = head_num, S = seq_len, H = size_per_head).
// Launch: gridDim.x == B * N * S (one block per src row), blockDim.x == H.
template <typename T>
__global__ void transpose(T *src, T *dst, const int batch_size,
                          const int seq_len, const int head_num,
                          const int size_per_head) {
  const int row = blockIdx.x;
  const int rows_per_batch = head_num * seq_len;
  const int b = row / rows_per_batch;
  const int n = (row % rows_per_batch) / seq_len;
  const int s = row % seq_len;
  const int h = threadIdx.x;

  const int dst_index = ((b * seq_len + s) * head_num + n) * size_per_head + h;
  dst[dst_index] = src[row * size_per_head + h];
}

// Compute QK*V for every (batch, head) pair:
//   dst = alpha * op(qk_buf_) x op(v_buf_) + beta * dst
// This is a batched GEMM over batch_size * head_num matrices:
//   qk: (seq_len x seq_len), stride seq_len * seq_len;
//   v:  (seq_len x size_per_head), stride seq_len * size_per_head.
// No transpose happens here; callers reorder dst afterwards if needed.
template <typename T>
void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num,
                       int seq_len, int size_per_head, int batch_size,
                       bool qk_trans, bool v_trans, T *v_buf_, const T *qk_buf_,
                       T *dst, T alpha, T beta) {
  CBLAS_TRANSPOSE transA = qk_trans ? CblasTrans : CblasNoTrans;
  CBLAS_TRANSPOSE transB = v_trans ? CblasTrans : CblasNoTrans;

  auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
  blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha,
                   qk_buf_, v_buf_, beta, dst, batch_size * head_num,
                   seq_len * seq_len, seq_len * size_per_head);
}
// Element-wise addition for the scalar and vector element types used by
// transpose_qkv_kernel (float, float2, float4). The kernel can then add the
// bias with a single call regardless of how many floats each thread moves.
template <typename T>
inline __device__ T add_func(T a, T b);

template <>
inline __device__ float add_func<float>(float a, float b) {
  return a + b;
}

template <>
inline __device__ float2 add_func<float2>(float2 a, float2 b) {
  return make_float2(a.x + b.x, a.y + b.y);
}

template <>
inline __device__ float4 add_func<float4>(float4 a, float4 b) {
  return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

// Adds the Q/K/V bias to the fused projection output and reorders it.
//   Input:  BxSx3xNxH
//   Bias:   3xNxH
//   Output: 3xBxNxSxH
// Launch layout (read from gridDim/blockDim below):
//   grid  = (S, B, 3), with blockIdx.z selecting Q/K/V;
//   block = (H, N).
// Each thread moves one T element. When T is a vector type (float2/float4),
// H counts vectors rather than floats.
template <typename T>
__global__ void transpose_qkv_kernel(const int H, const T *input, const T *bias,
                                     T *output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int N = blockDim.y;
  const int S = gridDim.x;
  const int B = gridDim.y;

  const int NH = N * H;
  const int NHS = NH * S;
  const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
  const int bias_offset = m * NH + n * H;
  const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;

  const int i = threadIdx.x;
  output[out_offset + i] =
      add_func(input[in_offset + i], bias[bias_offset + i]);
}
// Host launcher for transpose_qkv_kernel: BxSx3xNxH + 3xNxH -> 3xBxNxSxH.
// The element width is chosen from head_size:
//   - float4 if head_size % 4 == 0;
//   - else float2 if head_size % 2 == 0;
//   - else float.
// The vector paths reinterpret the pointers, so they assume input, bias and
// output are suitably aligned for float2/float4.
// Enforces (head_size / width) * head_num <= 1024 threads per block.
// The launch is enqueued on `stream`.
void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
                      const int head_num, const float *input, const float *bias,
                      float *output, cudaStream_t stream) {
  const dim3 grid(seq_len, batch, 3);
  if (head_size % 4 == 0) {
    const int h = head_size / 4;
    const float4 *input4 = reinterpret_cast<const float4 *>(input);
    const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
    float4 *output4 = reinterpret_cast<float4 *>(output);
    const dim3 block(h, head_num, 1);

    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 4));
    transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
                                                             output4);
  } else if (head_size % 2 == 0) {
    const int h = head_size / 2;
    const float2 *input2 = reinterpret_cast<const float2 *>(input);
    const float2 *bias2 = reinterpret_cast<const float2 *>(bias);
    float2 *output2 = reinterpret_cast<float2 *>(output);
    const dim3 block(h, head_num, 1);
    // limit h * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(h * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024 * 2));
    transpose_qkv_kernel<float2><<<grid, block, 0, stream>>>(h, input2, bias2,
                                                             output2);
  } else {
    const dim3 block(head_size, head_num, 1);
    // limit head_size * head_num to max block size(1024).
    PADDLE_ENFORCE_LE(head_size * head_num, 1024,
                      platform::errors::InvalidArgument(
                          "head_num (%d) * head_size (%d) should <= %d",
                          head_num, head_size, 1024));
    transpose_qkv_kernel<float><<<grid, block, 0, stream>>>(head_size, input,
                                                            bias, output);
  }
}
// Attention core. Q, K and V (already bias-added, e.g. by TransQKVWithBias)
// are stored back to back at tptr, each holding
// batch * head_num * seq_len * head_size elements. The steps are:
//   1. qkptr = softmax(alpha * Q x K^T + beta * qkptr + bias_qk)
//   2. Q slot of tptr = 1.0 * qkptr x V + beta * (old Q slot)
// qkptr is scratch for batch * head_num * seq_len * seq_len scores. The Q
// slot is overwritten with the attention output; K and V are only read.
template <typename T>
void MultiHeadGPUComputeV2(const platform::CUDADeviceContext &dev_ctx,
                           int batch, int seq_len, int head_num, int head_size,
                           T *qkptr, const T *bias_qk_ptr, T *tptr, T alpha,
                           T beta) {
  const int tsize = batch * head_num * seq_len * head_size;

  T *qptr = tptr;
  T *kptr = qptr + tsize;
  T *vptr = kptr + tsize;
  // Batched strided GEMM Q x K^T, then bias add + softmax.
  MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, head_size, batch, false, true,
                      qptr, kptr, qkptr, bias_qk_ptr, alpha, beta);
  // Batched strided GEMM softmax(QK) x V, written into the Q slot.
  MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, head_size, batch, false,
                       false, vptr, qkptr, tptr, T(1.0), beta);
}

template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 public:
  // Fused multi-head attention over Input (B, S, hidden):
  //   1. temp = Input x W, giving B*S rows laid out as 3 x N x H (Q|K|V).
  //   2. TransQKVWithBias: add Bias and reorder into 3xBxNxSxH.
  //   3. MultiHeadGPUComputeV2: softmax(alpha * Q x K^T + BiasQK) x V.
  //   4. transpose: reorder the per-head result into Out (B, S, N * H).
  void Compute(const framework::ExecutionContext &context) const override {
    using Tensor = framework::Tensor;
    auto *input = context.Input<framework::Tensor>("Input");
    auto *w = context.Input<framework::Tensor>("W");
    auto *bias = context.Input<framework::Tensor>("Bias");

    auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
                                "Cannot find QK");

    auto *bias_d = bias->data<T>();
    auto *bias_qk_d = bias_qk.data<T>();
    T scale = static_cast<T>(context.Attr<float>("alpha"));

    int head_number = context.Attr<int>("head_number");
    auto &device_ctx = context.template device_context<DeviceContext>();

    // Input should be (B, S, hidden).
    auto input_dims = input->dims();
    // W should be (hidden, 3, all_head_size).
    auto w_dims = w->dims();
    int batch = input_dims[0];
    int seq_len = input_dims[1];

    int all_head_size = w_dims[2];
    PADDLE_ENFORCE_GT(head_number, 0,
                      platform::errors::InvalidArgument(
                          "head_number should be > 0, but received %d",
                          head_number));
    PADDLE_ENFORCE_EQ(
        all_head_size % head_number, 0,
        platform::errors::InvalidArgument(
            "all_head_size (%d) should be divisible by head_number (%d)",
            all_head_size, head_number));
    int head_size = all_head_size / head_number;
    // The final transpose launches one thread per element of a head.
    PADDLE_ENFORCE_LE(head_size, 1024,
                      platform::errors::InvalidArgument(
                          "head_size should <= 1024, but received %d",
                          head_size));

    auto *out = context.Output<framework::Tensor>("Out");
    out->Resize({batch, seq_len, all_head_size});
    auto *output_d = out->mutable_data<T>(context.GetPlace());

    // (B*S, hidden)
    const Tensor input_matrix =
        framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
    // (hidden, 3 * all_head_size)
    const Tensor w_matrix =
        framework::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);

    // (B * S, 3 * N * H); the row width is computed directly rather than
    // dividing by B * S, which avoids a division by zero on empty input.
    Tensor temp_out_tensor;
    temp_out_tensor.Resize({batch * seq_len, 3 * head_number * head_size});
    auto *temp_out_data = temp_out_tensor.mutable_data<T>(context.GetPlace());

    // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(device_ctx);
    blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);

    // One scratch allocation:
    //   - B * head_number * S * S attention scores (qkptr), followed by
    //   - the transposed Q/K/V (tptr), same element count as temp_out_tensor.
    Tensor multihead_temp_tensor;
    int scratch_size = batch * head_number * seq_len * seq_len * 1;
    multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
    auto *multihead_temp_data =
        multihead_temp_tensor.mutable_data<T>(context.GetPlace());
    auto *qkptr = multihead_temp_data;
    auto *tptr = multihead_temp_data + scratch_size;

    auto stream = device_ctx.stream();
    // Do the transpose with bias.
    // BxSx3xNxH => tptr: 3xBxNxSxH.
    TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
                     bias_d, tptr, stream);

    MultiHeadGPUComputeV2<T>(device_ctx, batch, seq_len, head_number, head_size,
                             qkptr, bias_qk_d, tptr, scale, T(0.0));

    // BxNxSxH (attention output in the Q slot) => Out: BxSxNxH.
    int grid = batch * head_number * seq_len;
    int block = head_size;
    transpose<T><<<grid, block, 0, stream>>>(tptr, output_d, batch, seq_len,
                                             head_number, head_size);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Only float is registered; TransQKVWithBias operates on float buffers.
REGISTER_OP_CUDA_KERNEL(
    multihead_matmul,
    ops::MultiHeadMatMulV2Kernel<paddle::platform::CUDADeviceContext, float>);