bert_encoder_functor.cu 39.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

P
Pei Yang 已提交
15
#include <algorithm>
16

17 18 19 20
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/enforce.h"
21 22
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
23 24 25 26 27

namespace paddle {
namespace operators {
namespace math {

28
// NOTE(chenfeiyu): operator+ for float2 is pulled in explicitly because
// float2 does not live in namespace phi::funcs, so argument-dependent lookup
// cannot find it on its own.
using phi::funcs::operator+;

// Reciprocal square root used by the LayerNorm helpers. The generic overload
// evaluates in float and converts the result back to T; on architectures
// with native fp16 arithmetic, half inputs use the hrsqrt intrinsic instead.
template <typename T>
__device__ __forceinline__ T local_rsqrt(T x) {
  const float x_f = static_cast<float>(x);
  return rsqrt(x_f);
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
__device__ __forceinline__ half local_rsqrt(half x) {
  return hrsqrt(x);
}
#endif

40
// Block-wide layer normalization for a row held one element per thread.
//
// thread_data is this thread's partial (x / ld, x * x / ld) pair, as the
// kernels in this file build it. A cub::BlockReduce over TPB threads turns
// the pairs into (mean, E[x^2]). Thread 0 publishes the mean and
// 1 / sqrt(var + eps) through shared memory. Every thread with
// threadIdx.x < ld then writes
//   output[idx] = scale[threadIdx.x] * (val - mean) * rsigma + bias[threadIdx.x].
//
// Must be reached by all TPB threads of the block: it contains a block
// reduction and __syncthreads().
template <typename T, int TPB>
__device__ inline void LayerNormSmall(T val,
                                      const phi::funcs::kvp<T> &thread_data,
                                      const int ld,
                                      const int idx,
                                      const float *bias,
                                      const float *scale,
                                      T *output,
                                      T eps) {
  using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    // E[x^2] - E[x]^2 can come out slightly negative because of
    // floating-point cancellation (notably in fp16). Clamp it so rsqrt never
    // receives a negative argument and the whole row does not become NaN.
    T variance = sum_kv.value - mu * mu;
    if (variance < T(0.f)) {
      variance = T(0.f);
    }
    rsigma = local_rsqrt(variance + eps);
  }
  __syncthreads();

  if (threadIdx.x < ld) {
    const T g(scale[threadIdx.x]);
    const T b(bias[threadIdx.x]);
    output[idx] = g * (val - mu) * rsigma + b;
  }
}

// Block-wide, in-place layer normalization of output[offset, offset + ld).
//
// thread_data is this thread's partial (sum x / ld, sum x * x / ld) pair, as
// accumulated by the kernels in this file. It is block-reduced into
// (mean, E[x^2]). Thread 0 publishes the mean and 1 / sqrt(var + eps) through
// shared memory. Each thread then walks the row with stride TPB and rewrites
//   output[offset + i] = scale[i] * (x - mean) * rsigma + bias[i].
//
// Must be reached by all TPB threads of the block: it contains a block
// reduction and __syncthreads().
template <typename T, int TPB>
__device__ inline void LayerNorm(const phi::funcs::kvp<T> &thread_data,
                                 const int ld,
                                 const int offset,
                                 const float *bias,
                                 const float *scale,
                                 T *output,
                                 T eps) {
  using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    // E[x^2] - E[x]^2 can come out slightly negative because of
    // floating-point cancellation (notably in fp16). Clamp it so rsqrt never
    // receives a negative argument and the whole row does not become NaN.
    T variance = sum_kv.value - mu * mu;
    if (variance < T(0.f)) {
      variance = T(0.f);
    }
    rsigma = local_rsqrt(variance + eps);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < ld; i += TPB) {
    const int idx = offset + i;
    const T val = output[idx];
    const T g(scale[i]);
    const T b(bias[i]);
    output[idx] = g * (val - mu) * rsigma + b;
  }
}

99
// Paired variant of LayerNorm: output holds T2 pairs (.x, .y), and scale and
// bias are float2. Both components of output[offset, offset + ld) are
// normalized in place with the same block-wide mean and 1 / sqrt(var + eps),
// which come from reducing each thread's thread_data (mean, E[x^2])
// partials.
//
// Must be reached by all TPB threads of the block: it contains a block
// reduction and __syncthreads().
template <typename T, typename T2, int TPB>
__device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
                                  const int ld,
                                  const int offset,
                                  const float2 *bias,
                                  const float2 *scale,
                                  T2 *output,
                                  T eps) {
  using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    // E[x^2] - E[x]^2 can come out slightly negative because of
    // floating-point cancellation (notably in fp16). Clamp it so rsqrt never
    // receives a negative argument and the whole row does not become NaN.
    T variance = sum_kv.value - mu * mu;
    if (variance < T(0.f)) {
      variance = T(0.f);
    }
    rsigma = local_rsqrt(variance + eps);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < ld; i += TPB) {
    const int idx = offset + i;
    T2 val = output[idx];
    const float2 g = scale[i];
    const float2 b = bias[i];
    val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x);
    val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y);
    output[idx] = val;
  }
}

131
// Fused embedding lookup + element-wise sum + layer normalization.
//
// The functor below launches this kernel with grid = (seq_len, batch), TPB
// threads per block, and input_num * sizeof(int64_t) bytes of dynamic shared
// memory. ids[i] and embs[i] carry device addresses (stored as int64_t) of
// the i-th id tensor and the i-th embedding table. Each block sums the
// input_num embedding rows selected for its token into output, collects
// (mean, E[x^2]) partials and finishes with LayerNorm.
template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden,
                                          const int64_t *ids,
                                          const float *scale,
                                          const float *bias,
                                          const int64_t *embs,
                                          T *output,
                                          float eps,
                                          int input_num) {
  // blockIdx.x: position in the sequence, gridDim.x: Seq
  // blockIdx.y: batch,                    gridDim.y: Batch
  extern __shared__ int64_t array_id[];

  const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;

  // Thread 0 gathers this token's id from every id tensor.
  if (threadIdx.x == 0) {
    for (int input = 0; input < input_num; ++input) {
      const int64_t *id_table = reinterpret_cast<const int64_t *>(ids[input]);
      array_id[input] = id_table[seq_pos];
    }
  }
  __syncthreads();

  const T inv_hidden = T(1.f) / T(hidden);
  const int64_t row_base = seq_pos * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<T> thread_data(0, 0);

#pragma unroll
  for (int col = threadIdx.x; col < hidden; col += TPB) {
    T acc = 0;
    for (int input = 0; input < input_num; ++input) {
      const T *table = reinterpret_cast<const T *>(embs[input]);
      acc += table[array_id[input] * hidden + col];
    }
    output[row_base + col] = acc;
    const T scaled = inv_hidden * acc;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<T>(scaled, scaled * acc));
  }
  LayerNorm<T, TPB>(thread_data, hidden, row_base, bias, scale, output, eps);
}

177 178
// HIP defines __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so this half
// specialization is compiled for CUDA only.
#ifndef __HIPCC__  // @{ Half kernel: EmbEltwiseLayernormKernel
// half specialization for 256-thread blocks. The body is compiled only on
// architectures with native fp16 arithmetic; elsewhere the kernel is empty.
template <>
__global__ void EmbEltwiseLayernormKernel<half, 256>(int hidden,
                                                     const int64_t *ids,
                                                     const float *scale,
                                                     const float *bias,
                                                     const int64_t *embs,
                                                     half *output,
                                                     float eps,
                                                     int input_num) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  // blockIdx.x: position in the sequence, gridDim.x: Seq
  // blockIdx.y: batch,                    gridDim.y: Batch
  extern __shared__ int64_t array_id[];

  const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;

  // Thread 0 gathers this token's id from every id tensor.
  if (threadIdx.x == 0) {
    for (int input = 0; input < input_num; ++input) {
      const int64_t *id_table = reinterpret_cast<const int64_t *>(ids[input]);
      array_id[input] = id_table[seq_pos];
    }
  }
  __syncthreads();

  const half inv_hidden = half(1.f) / half(hidden);
  const int64_t row_base = seq_pos * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<half> thread_data(0, 0);

#pragma unroll
  for (int col = threadIdx.x; col < hidden; col += 256) {
    half acc = 0;
    for (int input = 0; input < input_num; ++input) {
      const half *table = reinterpret_cast<const half *>(embs[input]);
      acc += table[array_id[input] * hidden + col];
    }
    output[row_base + col] = acc;
    const half scaled = inv_hidden * acc;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<half>(scaled, scaled * acc));
  }
  LayerNorm<half, 256>(
      thread_data, hidden, row_base, bias, scale, output, eps);
#endif
}
#endif  // @} End Half kernel: EmbEltwiseLayernormKernel
228

229
// Launches EmbEltwiseLayernormKernel on `stream`: one 256-thread block per
// (sequence position, batch) token, with one int64_t of dynamic shared
// memory per input for the gathered ids.
template <typename T>
void EmbEltwiseLayerNormFunctor<T>::operator()(int batch,
                                               int seq_len,
                                               int hidden,
                                               const int64_t *ids,
                                               const float *scale,
                                               const float *bias,
                                               const int64_t *embs,
                                               T *output,
                                               float eps,
                                               int input_num,
                                               gpuStream_t stream) {
  constexpr unsigned kThreads = 256;
  const dim3 grid_dim(seq_len, batch, 1);
  const dim3 block_dim(kThreads, 1, 1);
  const int smem_bytes = input_num * sizeof(int64_t);
  EmbEltwiseLayernormKernel<T, kThreads>
      <<<grid_dim, block_dim, smem_bytes, stream>>>(
          hidden, ids, scale, bias, embs, output, eps, input_num);
}

template class EmbEltwiseLayerNormFunctor<float>;

// Device-side operator() on half is not supported before CUDA 10.0, and HIP
// defines __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so the half
// instantiation is restricted to CUDA >= 10.0.
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class EmbEltwiseLayerNormFunctor<half>;
#endif

// In-place row softmax of (qk_buf_ + bias_qk_), evaluated in float.
// One block per row of seq_len elements (row = blockIdx.x). Each thread owns
// at most one element, so blockDim.x must cover seq_len and be a multiple of
// 32. Lanes past seq_len only take part in the block reductions.
template <typename T>
__global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
                                        const T *bias_qk_,
                                        const int batch_size,
                                        const int head_num,
                                        const int seq_len,
                                        const unsigned mask) {
  const int row_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);
  const bool active = threadIdx.x < seq_len;
  const auto pos = threadIdx.x + row_offset;

  float logit = -1e20f;  // padding lanes never win the max
  if (active) {
    logit = static_cast<float>(qk_buf_[pos] + bias_qk_[pos]);
  }
  const float row_max = phi::funcs::blockReduceMax<float>(logit, mask);

  const float numer = active ? __expf(logit - row_max) : 0.0f;
  const float denom = phi::funcs::blockReduceSum<float>(numer, mask);

  if (active) {
    qk_buf_[pos] = (T)(numer / denom);
  }
}

280 281
// HIP defines __HIP_NO_HALF_CONVERSIONS__, so this half specialization is
// compiled for CUDA only.
#ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
// half specialization of the one-element-per-thread row softmax. Math runs
// in float; the body exists only on architectures with native fp16 support.
template <>
__global__ void SoftmaxKernelWithEltadd<half>(half *qk_buf_,
                                              const half *bias_qk_,
                                              const int batch_size,
                                              const int head_num,
                                              const int seq_len,
                                              const unsigned mask) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const int row_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);
  const bool active = threadIdx.x < seq_len;
  const auto pos = threadIdx.x + row_offset;

  float logit = -1e20f;  // padding lanes never win the max
  if (active) {
    logit = static_cast<float>(qk_buf_[pos] + bias_qk_[pos]);
  }
  const float row_max = phi::funcs::blockReduceMax<float>(logit, mask);

  const float numer = active ? __expf(logit - row_max) : 0.0f;
  const float denom = phi::funcs::blockReduceSum<float>(numer, mask);

  if (active) {
    qk_buf_[pos] = (half)(numer / denom);
  }
#endif
}
#endif  // @} End Half kernel: SoftmaxKernelWithEltadd
307

308
// Paired row softmax: each T element packs two scores (e.g. float2). One
// block per row of seq_len pairs (row = blockIdx.x), at most one pair per
// thread. Both components share the block-wide max and sum. A 1e-6 epsilon
// is added to the denominator.
template <typename T>
__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
                                         const T *bias_qk_,
                                         const int batch_size,
                                         const int head_num,
                                         const int seq_len,
                                         const unsigned mask) {
  const int row_offset = blockIdx.x * seq_len;
  const int lane = threadIdx.x;
  assert(blockDim.x % 32 == 0);
  const bool active = lane < seq_len;
  const int pos = lane + row_offset;

  float2 logits = make_float2(-1e20f, -1e20f);
  if (active) {
    logits = phi::funcs::ToFloat2<T>(qk_buf_[pos] + bias_qk_[pos]);
  }
  const float row_max =
      phi::funcs::blockReduceMax<float>(max(logits.x, logits.y), mask);

  float2 numer = make_float2(0.f, 0.f);
  if (active) {
    numer.x = __expf(logits.x - row_max);
    numer.y = __expf(logits.y - row_max);
  }
  const float denom =
      phi::funcs::blockReduceSum<float>(numer.x + numer.y, mask) + 1e-6f;

  if (active) {
    qk_buf_[pos] =
        phi::funcs::FloatsToPair<T>(numer.x / denom, numer.y / denom);
  }
}

336
// half2 specialization of the paired row softmax (math in float).
template <>
__global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
                                                const half2 *bias_qk_,
                                                const int batch_size,
                                                const int head_num,
                                                const int seq_len,
                                                const unsigned mask) {
// operator "+" on half is only supported from CUDA 10.0, and HIP defines
// __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so the body is CUDA-only.
#if defined(PADDLE_WITH_CUDA) && \
    (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
  const int row_offset = blockIdx.x * seq_len;
  const int lane = threadIdx.x;
  assert(blockDim.x % 32 == 0);
  const bool active = lane < seq_len;
  const int pos = lane + row_offset;

  float2 logits = make_float2(-1e20f, -1e20f);
  if (active) {
    logits = phi::funcs::ToFloat2<half2>(qk_buf_[pos] + bias_qk_[pos]);
  }
  const float row_max =
      phi::funcs::blockReduceMax<float>(max(logits.x, logits.y), mask);

  float2 numer = make_float2(0.f, 0.f);
  if (active) {
    numer.x = __expf(logits.x - row_max);
    numer.y = __expf(logits.y - row_max);
  }
  const float denom =
      phi::funcs::blockReduceSum<float>(numer.x + numer.y, mask) + 1e-6f;

  if (active) {
    qk_buf_[pos] =
        phi::funcs::FloatsToPair<half2>(numer.x / denom, numer.y / denom);
  }
#endif
}

369
// In-place row softmax of (qk_buf + bias_qk) for rows longer than a block.
// One block per row (row = blockIdx.x); each thread strides over the row
// with step blockDim.x.
//
// Every loop is bounded by seq_len. Without that bound, the tail iteration
// reads and overwrites the next row (or runs past the buffer) whenever
// seq_len is not a multiple of blockDim.x. Threads with no elements still
// join the block reductions, with neutral values (-1e20 for max, 0 for sum).
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
                                                const T *bias_qk,
                                                const int batch_size,
                                                const int head_num,
                                                const int seq_len,
                                                const unsigned mask) {
  int qk_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Pass 1: row maximum.
  T stride_max = -1e20f;
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const T val = qk_buf[col + qk_offset] + bias_qk[col + qk_offset];
    stride_max = val > stride_max ? val : stride_max;
  }
  T max_val = phi::funcs::blockReduceMax<T>(stride_max, mask);

  // Pass 2: sum of exponentials.
  T stride_sum = 0.f;
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    stride_sum +=
        __expf(qk_buf[col + qk_offset] + bias_qk[col + qk_offset] - max_val);
  }
  T sum_val = phi::funcs::blockReduceSum<T>(stride_sum, mask);

  // Pass 3: normalize in place.
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    qk_buf[col + qk_offset] =
        (T)(__expf(qk_buf[col + qk_offset] + bias_qk[col + qk_offset] -
                   max_val) /
            sum_val);
  }
}

// HIP defines __HIP_NO_HALF_CONVERSIONS__, so this half specialization is
// compiled for CUDA only.
#ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
// half specialization of the strided long-row softmax; math runs in float.
// Loops are bounded by seq_len so the tail iteration never touches the next
// row when seq_len is not a multiple of blockDim.x. Idle threads still join
// the block reductions with neutral values.
template <>
__global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
                                                const half *bias_qk,
                                                const int batch_size,
                                                const int head_num,
                                                const int seq_len,
                                                const unsigned mask) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  int qk_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Pass 1: row maximum.
  float stride_max = -1e20f;
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float tmp = static_cast<float>(qk_buf[col + qk_offset] +
                                         bias_qk[col + qk_offset]);
    stride_max = tmp > stride_max ? tmp : stride_max;
  }
  float max_val = phi::funcs::blockReduceMax<float>(stride_max, mask);

  // Pass 2: sum of exponentials.
  float stride_sum = 0.f;
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float tmp = static_cast<float>(qk_buf[col + qk_offset] +
                                         bias_qk[col + qk_offset]);
    stride_sum += __expf(tmp - max_val);
  }
  float sum_val = phi::funcs::blockReduceSum<float>(stride_sum, mask);

  // Pass 3: normalize in place.
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float tmp =
        __expf(static_cast<float>(qk_buf[col + qk_offset] +
                                  bias_qk[col + qk_offset]) -
               max_val);
    qk_buf[col + qk_offset] = (half)(tmp / sum_val);
  }
#endif
}
#endif  // @} End Half kernel: SoftmaxKernelWithEltadd

// Paired long-row softmax: each T packs two scores (e.g. float2). One block
// per row of seq_len pairs (row = blockIdx.x); threads stride over the row
// with step blockDim.x. Both components share the block-wide max and sum,
// and a 1e-6 epsilon is added to the denominator.
//
// Loops are bounded by seq_len so the tail iteration never reads or writes
// past the row when seq_len is not a multiple of blockDim.x.
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
                                                 const T *bias_qk_,
                                                 const int batch_size,
                                                 const int head_num,
                                                 const int seq_len,
                                                 const unsigned mask) {
  int qk_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Pass 1: row maximum over both components.
  float2 stride_max = make_float2(-1e20f, -1e20f);
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[col + qk_offset] +
                                               bias_qk_[col + qk_offset]);
    stride_max.x = max(stride_max.x, cur.x);
    stride_max.y = max(stride_max.y, cur.y);
  }
  float max_val =
      phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);

  // Pass 2: sum of exponentials.
  float2 stride_sum = make_float2(0.f, 0.f);
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[col + qk_offset] +
                                               bias_qk_[col + qk_offset]);
    stride_sum.x += __expf(cur.x - max_val);
    stride_sum.y += __expf(cur.y - max_val);
  }
  float sum_val =
      phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
      1e-6f;

  // Pass 3: normalize in place.
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[col + qk_offset] +
                                               bias_qk_[col + qk_offset]);
    qk_buf_[col + qk_offset] = phi::funcs::FloatsToPair<T>(
        __expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
  }
}

// half2 specialization of the paired long-row softmax (math in float).
// Loops are bounded by seq_len so the tail iteration never reads or writes
// past the row when seq_len is not a multiple of blockDim.x.
template <>
__global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
                                                 const half2 *bias_qk_,
                                                 const int batch_size,
                                                 const int head_num,
                                                 const int seq_len,
                                                 const unsigned mask) {
// operator "+" on half is only supported from CUDA 10.0, and HIP defines
// __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so the body is CUDA-only.
#if defined(PADDLE_WITH_CUDA) && \
    (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)

  int qk_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Pass 1: row maximum over both components.
  float2 stride_max = make_float2(-1e20f, -1e20f);
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float2 cur =
        phi::funcs::ToFloat2<half2>(qk_buf_[col + qk_offset] +
                                    bias_qk_[col + qk_offset]);
    stride_max.x = max(stride_max.x, cur.x);
    stride_max.y = max(stride_max.y, cur.y);
  }
  float max_val =
      phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);

  // Pass 2: sum of exponentials.
  float2 stride_sum = make_float2(0.f, 0.f);
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float2 cur =
        phi::funcs::ToFloat2<half2>(qk_buf_[col + qk_offset] +
                                    bias_qk_[col + qk_offset]);
    stride_sum.x += __expf(cur.x - max_val);
    stride_sum.y += __expf(cur.y - max_val);
  }
  float sum_val =
      phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
      1e-6f;

  // Pass 3: normalize in place.
  for (int col = threadIdx.x; col < seq_len; col += blockDim.x) {
    const float2 cur =
        phi::funcs::ToFloat2<half2>(qk_buf_[col + qk_offset] +
                                    bias_qk_[col + qk_offset]);
    qk_buf_[col + qk_offset] = phi::funcs::FloatsToPair<half2>(
        __expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
  }
#endif
}

534
// qk_buf_ = alpha * op(Q) * op(K) + beta * qk_buf_ for each of the
// batch_size * head_num heads, via one strided batched GEMM. This is
// followed by an in-place row softmax of qk_buf_ + bias_qk on the context
// stream.
//
// Softmax dispatch:
//   * seq_len <= 1024: one thread per element, or per element pair when
//     seq_len is even (float2 for float, __half2 otherwise);
//   * seq_len  > 1024: 512-thread blocks that stride over each row.
template <typename T>
inline void MatMulWithHeadQK(const phi::GPUContext &context,
                             int head_num,
                             int seq_len,
                             int size_per_head,
                             int batch_size,
                             bool q_trans,
                             bool k_trans,
                             T *q_buf_,
                             T *k_buf_,
                             T *qk_buf_,
                             const T *bias_qk,
                             T alpha,
                             T beta) {
  const CBLAS_TRANSPOSE trans_q = q_trans ? CblasTrans : CblasNoTrans;
  const CBLAS_TRANSPOSE trans_k = k_trans ? CblasTrans : CblasNoTrans;

  typedef typename CUDATypeTraits<T>::TYPE run_type;
  auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
  auto stream = context.stream();

  const int head_stride = seq_len * size_per_head;
  blas.BatchedGEMM(trans_q,
                   trans_k,
                   seq_len,
                   seq_len,
                   size_per_head,
                   static_cast<run_type>(alpha),
                   reinterpret_cast<run_type *>(q_buf_),
                   reinterpret_cast<run_type *>(k_buf_),
                   static_cast<run_type>(beta),
                   reinterpret_cast<run_type *>(qk_buf_),
                   batch_size * head_num,
                   head_stride,
                   head_stride);

  // One block per softmax row, i.e. per (batch, head, query position).
  const int grid = batch_size * head_num * seq_len;
  const bool paired = (seq_len % 2 == 0);
  const int pair_len = seq_len / 2;
  const bool is_fp32 = std::is_same<T, float>::value;

  if (seq_len > 1024) {
    // Long rows: fixed-size blocks that stride over the row.
    const int block = 512;
    if (!paired) {
      SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>(
          qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
    } else if (is_fp32) {
      SoftmaxKernelWithEltaddForLarge2<float2><<<grid, block, 0, stream>>>(
          reinterpret_cast<float2 *>(qk_buf_),
          reinterpret_cast<const float2 *>(bias_qk),
          batch_size,
          head_num,
          pair_len,
          FINAL_MASK);
    } else {
      SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
          reinterpret_cast<__half2 *>(qk_buf_),
          reinterpret_cast<const __half2 *>(bias_qk),
          batch_size,
          head_num,
          pair_len,
          FINAL_MASK);
    }
    return;
  }

  // Short rows: the block covers the whole row (or half of it when paired),
  // rounded up to whole warps with a one-warp minimum.
  if (!paired) {
    const int block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
    SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
        qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
    return;
  }

  const int block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
  if (is_fp32) {
    SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
        reinterpret_cast<float2 *>(qk_buf_),
        reinterpret_cast<const float2 *>(bias_qk),
        batch_size,
        head_num,
        pair_len,
        FINAL_MASK);
  } else {
    SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
        reinterpret_cast<__half2 *>(qk_buf_),
        reinterpret_cast<const __half2 *>(bias_qk),
        batch_size,
        head_num,
        pair_len,
        FINAL_MASK);
  }
}

// dst = alpha * op(QK) * op(V) + beta * dst for each of the
// batch_size * head_num heads, via one strided batched GEMM with
// M = seq_len, N = size_per_head, K = seq_len. Per-head strides are
// seq_len * seq_len for qk_buf_ and seq_len * size_per_head for v_buf_.
template <typename T>
inline void MatMulWithHeadQKV(const phi::GPUContext &context,
                              int head_num,
                              int seq_len,
                              int size_per_head,
                              int batch_size,
                              bool qk_trans,
                              bool v_trans,
                              T *v_buf_,
                              const T *qk_buf_,
                              T *dst,
                              T alpha,
                              T beta) {
  typedef typename CUDATypeTraits<T>::TYPE run_type;
  auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
  CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;

  blas.BatchedGEMM(transA,
                   transB,
                   seq_len,
                   size_per_head,
                   seq_len,
                   static_cast<run_type>(alpha),
                   reinterpret_cast<const run_type *>(qk_buf_),
                   reinterpret_cast<run_type *>(v_buf_),
                   static_cast<run_type>(beta),
                   reinterpret_cast<run_type *>(dst),
                   batch_size * head_num,
                   seq_len * seq_len,
                   seq_len * size_per_head);
}

// Attention core over a packed Q/K/V buffer.
//
// tptr holds Q, K and V back to back, each batch * head_num * seq_len *
// head_size elements long. The functor first computes alpha * Q * K^T
// (+ beta * qkptr) into qkptr and applies the bias-add row softmax. It then
// multiplies the result by V and writes the product over the Q slice at the
// start of tptr.
template <typename T>
void MultiHeadGPUComputeFunctor<T>::operator()(const phi::GPUContext &dev_ctx,
                                               int batch,
                                               int seq_len,
                                               int head_num,
                                               int head_size,
                                               T *qkptr,
                                               const T *bias_qk_ptr,
                                               T *tptr,
                                               T alpha,
                                               T beta) {
  // Element count per Q/K/V slice, computed in 64 bits so large shapes do
  // not overflow int before the pointer arithmetic below.
  const int64_t tsize =
      static_cast<int64_t>(batch) * head_num * seq_len * head_size;

  T *qptr = tptr;
  T *kptr = qptr + tsize;
  T *vptr = kptr + tsize;
  // Strided batched GEMM Q * K^T, followed by the bias-add softmax.
  MatMulWithHeadQK<T>(dev_ctx,
                      head_num,
                      seq_len,
                      head_size,
                      batch,
                      false,
                      true,
                      qptr,
                      kptr,
                      qkptr,
                      bias_qk_ptr,
                      alpha,
                      beta);
  // Strided batched GEMM softmax(QK) * V, written over the Q slice of tptr.
  MatMulWithHeadQKV<T>(dev_ctx,
                       head_num,
                       seq_len,
                       head_size,
                       batch,
                       false,
                       false,
                       vptr,
                       qkptr,
                       tptr,
                       T(1.0),
                       beta);
}

template class MultiHeadGPUComputeFunctor<float>;

// Device-side operator() on half is not supported before CUDA 10.0, and HIP
// defines __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so the half
// instantiation is restricted to CUDA >= 10.0.
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class MultiHeadGPUComputeFunctor<half>;
#endif

// Residual add + layer norm for short rows: block blockIdx.x owns one row of
// `hidden` elements, and thread t handles column t (only columns with
// threadIdx.x < hidden are read). It forms input1 + input2, contributes its
// (x / hidden, x * x / hidden) partial, and LayerNormSmall writes the
// normalized value to output.
template <typename T, unsigned TPB>
__global__ void SkipLayerNormSmallKernel(int num,
                                         int hidden,
                                         const T *input1,
                                         const T *input2,
                                         T *output,
                                         const float *scale,
                                         const float *bias,
                                         float eps) {
  const int row_base = blockIdx.x * hidden;
  const int idx = row_base + threadIdx.x;
  const T inv_hidden = T(1) / T(hidden);

  cub::Sum pair_sum;
  phi::funcs::kvp<T> thread_data(0, 0);
  T val = 0;
  if (threadIdx.x < hidden) {
    val = input1[idx] + input2[idx];
    const T scaled = inv_hidden * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<T>(scaled, scaled * val));
  }
  LayerNormSmall<T, TPB>(
      val, thread_data, hidden, idx, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormSmallKernel
// Shared body of the half-precision SkipLayerNormSmallKernel specializations
// below. It matches the generic SkipLayerNormSmallKernel, with two differences:
//   - statistics are accumulated in half precision;
//   - the body is compiled only where CUDA_ARCH_FP16_SUPPORTED holds. On other
//     architectures it is a no-op, as each specialization was before this
//     helper was factored out.
// Precondition: hidden <= TPB (one element per thread, no loop over the row).
template <unsigned TPB>
__device__ __forceinline__ void SkipLayerNormSmallHalfImpl(int hidden,
                                                          const half *input1,
                                                          const half *input2,
                                                          half *output,
                                                          const float *scale,
                                                          const float *bias,
                                                          float eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const half rld = half(1) / half(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  // Per-thread partial of (x / hidden, x^2 / hidden); zero for idle threads.
  phi::funcs::kvp<half> thread_data(0, 0);
  const int idx = offset + threadIdx.x;
  half val = 0;
  if (threadIdx.x < hidden) {
    val = input1[idx] + input2[idx];
    const half rldval = rld * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<half>(rldval, rldval * val));
  }
  LayerNormSmall<half, TPB>(
      val, thread_data, hidden, idx, bias, scale, output, eps);
#endif
}

// Half specialization for 32-thread blocks (used when hidden <= 32).
template <>
__global__ void SkipLayerNormSmallKernel<half, 32>(int num,
                                                   int hidden,
                                                   const half *input1,
                                                   const half *input2,
                                                   half *output,
                                                   const float *scale,
                                                   const float *bias,
                                                   float eps) {
  SkipLayerNormSmallHalfImpl<32>(
      hidden, input1, input2, output, scale, bias, eps);
}

// Half specialization for 128-thread blocks (used when 32 < hidden <= 128).
template <>
__global__ void SkipLayerNormSmallKernel<half, 128>(int num,
                                                    int hidden,
                                                    const half *input1,
                                                    const half *input2,
                                                    half *output,
                                                    const float *scale,
                                                    const float *bias,
                                                    float eps) {
  SkipLayerNormSmallHalfImpl<128>(
      hidden, input1, input2, output, scale, bias, eps);
}

// Half specialization for 384-thread blocks (used when hidden == 384).
template <>
__global__ void SkipLayerNormSmallKernel<half, 384>(int num,
                                                    int hidden,
                                                    const half *input1,
                                                    const half *input2,
                                                    half *output,
                                                    const float *scale,
                                                    const float *bias,
                                                    float eps) {
  SkipLayerNormSmallHalfImpl<384>(
      hidden, input1, input2, output, scale, bias, eps);
}
#endif  // @} End Half kernel: SkipLayerNormSmallKernel

// Fused residual add (input1 + input2) followed by LayerNorm, for arbitrary
// row widths.
//
// Launch layout: one block per row. The loop strides by TPB, so blockDim.x is
// expected to equal TPB. Each thread walks the row in steps of TPB, writes the
// un-normalized sum into `output`, and accumulates its (mean, E[x^2])
// partial. LayerNorm (defined earlier in this file) then receives `output`
// and the row offset; presumably it normalizes `output` in place, so confirm
// against its definition.
// `num` is accepted for signature parity with the other kernels but unused.
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(int num,
                                    int hidden,
                                    const T *input1,
                                    const T *input2,
                                    T *output,
                                    const float *scale,
                                    const float *bias,
                                    float eps) {
  const T rld = T(1) / T(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  // Running partial of (x / hidden, x^2 / hidden) over this thread's elements.
  phi::funcs::kvp<T> thread_data(0, 0);

  for (int it = threadIdx.x; it < hidden; it += TPB) {
    const int idx = offset + it;
    const T val = input1[idx] + input2[idx];
    const T rldval = rld * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<T>(rldval, rldval * val));
    output[idx] = val;
  }
  LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel
// Half-precision specialization of SkipLayerNormKernel for 256-thread blocks.
// It matches the generic kernel (statistics accumulated in half precision).
// The body is compiled only where CUDA_ARCH_FP16_SUPPORTED holds; on other
// architectures the kernel is an empty no-op.
template <>
__global__ void SkipLayerNormKernel<half, 256>(int num,
                                               int hidden,
                                               const half *input1,
                                               const half *input2,
                                               half *output,
                                               const float *scale,
                                               const float *bias,
                                               float eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const half rld = half(1) / half(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<half> thread_data(0, 0);

  for (int it = threadIdx.x; it < hidden; it += 256) {
    const int idx = offset + it;
    const half val = input1[idx] + input2[idx];
    const half rldval = rld * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<half>(rldval, rldval * val));
    output[idx] = val;
  }
  LayerNorm<half, 256>(thread_data, hidden, offset, bias, scale, output, eps);
#endif
}
#endif  // @} End Half kernel: SkipLayerNormKernel

// Vectorized fused residual add + LayerNorm.
//
// Each element of input1/input2/output is a 2-wide vector T2 (float2 or
// half2 in this file) packing two adjacent scalars. `hidden` is the row width
// in T2 units, i.e. half the scalar row width. scale/bias are read as float2.
// Launch layout: one block per row, TPB threads per block; each thread walks
// the row in steps of TPB. The (mean, E[x^2]) statistics are accumulated in T.
// `num` is accepted for signature parity with the other kernels but unused.
template <typename T, typename T2, unsigned TPB>
__global__ void SkipLayerNormKernel2(int num,
                                     int hidden,
                                     const T2 *input1,
                                     const T2 *input2,
                                     T2 *output,
                                     const float2 *scale,
                                     const float2 *bias,
                                     float eps) {
  // `hidden` counts T2 pairs, so 0.5f / hidden == 1 / (scalar row width).
  const T rld = T(0.5f / hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<T> thread_data(0, 0);

  for (int it = threadIdx.x; it < hidden; it += TPB) {
    const int idx = offset + it;
    const T2 val2 = input1[idx] + input2[idx];
    // Both scalars of the pair contribute to (mean, E[x^2]).
    thread_data = pair_sum(
        thread_data,
        phi::funcs::kvp<T>(rld * (val2.x + val2.y),
                           rld * val2.x * val2.x + rld * val2.y * val2.y));
    output[idx] = val2;
  }
  LayerNorm2<T, T2, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel2
// half2 specialization of SkipLayerNormKernel2 for 256-thread blocks.
// `hidden` counts half2 pairs, and statistics are accumulated in half
// precision. The body is compiled only on fp16-capable architectures with
// CUDA >= 10.0; otherwise the kernel is an empty no-op.
template <>
__global__ void SkipLayerNormKernel2<half, half2, 256>(int num,
                                                       int hidden,
                                                       const half2 *input1,
                                                       const half2 *input2,
                                                       half2 *output,
                                                       const float2 *scale,
                                                       const float2 *bias,
                                                       float eps) {
// operator "+" of half is only supported since cuda version 10.0
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
  // `hidden` counts half2 pairs, so 0.5f / hidden == 1 / (scalar row width).
  const half rld = half(0.5f / hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<half> thread_data(0, 0);

  for (int it = threadIdx.x; it < hidden; it += 256) {
    const int idx = offset + it;
    const half2 val2 = input1[idx] + input2[idx];
    thread_data = pair_sum(
        thread_data,
        phi::funcs::kvp<half>(rld * (val2.x + val2.y),
                              rld * val2.x * val2.x + rld * val2.y * val2.y));
    output[idx] = val2;
  }
  LayerNorm2<half, half2, 256>(
      thread_data, hidden, offset, bias, scale, output, eps);
#endif
}
#endif  // @} End Half kernel: SkipLayerNormKernel2

// Launches the fused "input1 + input2, then LayerNorm" kernel on `stream`.
//
// The input is treated as num / hidden rows of `hidden` elements, with one
// block launched per row. Kernel selection:
//   * hidden <= 32           -> SkipLayerNormSmallKernel, 32 threads
//   * hidden <= 128          -> SkipLayerNormSmallKernel, 128 threads
//   * hidden == 384          -> SkipLayerNormSmallKernel, 384 threads
//   * otherwise, even hidden -> SkipLayerNormKernel2 (2-wide vector access),
//                               256 threads, for float and (non-HIP) half
//   * anything else          -> SkipLayerNormKernel, 256 threads
// NOTE(review): the vectorized path reinterprets input1/input2/output and
// scale/bias as 2-wide vector pointers, so it assumes each pointer is aligned
// to that vector type (8 bytes for float2). Confirm that callers guarantee
// this.
// No launch error is checked here.
template <typename T>
void SkipLayerNormFunctor<T>::operator()(const int num,
                                         const int hidden,
                                         const T *input1,
                                         const T *input2,
                                         const float *scale,
                                         const float *bias,
                                         T *output,
                                         T eps,
                                         gpuStream_t stream) {
  // Avoid dividing by zero below; with no rows there is nothing to do.
  if (hidden <= 0 || num <= 0) {
    return;
  }
  const int block = num / hidden;
  // A zero-block grid is an invalid launch configuration; skip it.
  if (block == 0) {
    return;
  }

  if (hidden <= 32) {
    const int threads = 32;
    SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
        num, hidden, input1, input2, output, scale, bias, eps);
    return;
  }
  if (hidden <= 128) {
    const int threads = 128;
    SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
        num, hidden, input1, input2, output, scale, bias, eps);
    return;
  }
  if (hidden == 384) {
    const int threads = 384;
    SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
        num, hidden, input1, input2, output, scale, bias, eps);
    return;
  }

  const int threads = 256;
  if (hidden % 2 == 0) {
    if (std::is_same<T, float>::value) {
      SkipLayerNormKernel2<float, float2, threads>
          <<<block, threads, 0, stream>>>(
              num,
              hidden / 2,
              reinterpret_cast<const float2 *>(input1),
              reinterpret_cast<const float2 *>(input2),
              reinterpret_cast<float2 *>(output),
              reinterpret_cast<const float2 *>(scale),
              reinterpret_cast<const float2 *>(bias),
              eps);
      return;
    }
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__
    if (std::is_same<T, __half>::value) {
      SkipLayerNormKernel2<__half, __half2, threads>
          <<<block, threads, 0, stream>>>(
              num,
              hidden / 2,
              reinterpret_cast<const __half2 *>(input1),
              reinterpret_cast<const __half2 *>(input2),
              reinterpret_cast<__half2 *>(output),
              reinterpret_cast<const float2 *>(scale),
              reinterpret_cast<const float2 *>(bias),
              eps);
      return;
    }
#endif
  }
  // Odd hidden size, or an element type without a vectorized kernel. Fall
  // back to the scalar kernel rather than silently leaving `output`
  // unwritten (a bare assert(false) is a no-op in NDEBUG builds).
  SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
      num, hidden, input1, input2, output, scale, bias, eps);
}

// Explicit instantiations of SkipLayerNormFunctor for the element types this
// translation unit supports.
template class SkipLayerNormFunctor<float>;

// device function 'operator()' is not supported until cuda 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class SkipLayerNormFunctor<half>;
#endif

}  // namespace math
}  // namespace operators
}  // namespace paddle