bert_encoder_functor.cu 39.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

P
Pei Yang 已提交
15
#include <algorithm>
16
#include <type_traits>
17

18 19 20 21
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/enforce.h"
22 23
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
24 25 26 27 28

namespace paddle {
namespace operators {
namespace math {

// NOTE(chenfeiyu): explicitly use operator+ for float2
// since float2 is not in namespace phi::funcs, ADL won't help
using phi::funcs::operator+;

// Reciprocal square root used by the LayerNorm helpers in this file.
// Generic version: evaluates rsqrt in float precision and converts the result
// back to T on return.
template <typename T>
__device__ __forceinline__ T local_rsqrt(T num) {
  return rsqrt(static_cast<float>(num));
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
// half overload: uses the native half-precision intrinsic hrsqrt when the
// target architecture supports fp16 arithmetic.
__device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); }
#endif

// Block-wide layer normalization of one row whose length fits in a single
// block: each thread owns at most one element, and only threads with
// threadIdx.x < ld write a result.
//
// Template params:
//   T   - element and statistics type.
//   TPB - threads per block; sizes the cub::BlockReduce, so it must equal
//         blockDim.x.
// Params:
//   val         - this thread's un-normalized input element.
//   thread_data - this thread's partial (key, value) pair. The block-wide sum
//                 of keys is used as the mean and the sum of values as E[x^2]
//                 (callers in this file pass x / ld and x * x / ld).
//   ld          - row length.
//   idx         - index in `output` written by this thread.
//   bias, scale - per-column beta / gamma, indexed by threadIdx.x.
//   output      - destination buffer.
//   eps         - added to the variance before the reciprocal square root.
//
// Performs a block reduction and __syncthreads(), so every thread of the block
// must call it.
template <typename T, int TPB>
__device__ inline void LayerNormSmall(T val,
                                      const phi::funcs::kvp<T> &thread_data,
                                      const int ld,
                                      const int idx,
                                      const T *bias,
                                      const T *scale,
                                      T *output,
                                      T eps) {
  using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  // cub's BlockReduce result is only valid in thread 0, so thread 0 publishes
  // the statistics through shared memory.
  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    // Var[x] = E[x^2] - mean^2.
    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
  }
  __syncthreads();

  if (threadIdx.x < ld) {
    const T g(scale[threadIdx.x]);
    const T b(bias[threadIdx.x]);
    output[idx] = g * (val - mu) * rsigma + b;
  }
}

// Block-wide layer normalization of one row of `ld` elements that has already
// been written to output[offset, offset + ld). The row is normalized in place,
// with each thread handling elements threadIdx.x, threadIdx.x + TPB, ...
//
// Template params:
//   T   - element and statistics type.
//   TPB - threads per block; sizes the cub::BlockReduce and is the loop
//         stride, so it must equal blockDim.x.
// Params:
//   thread_data - this thread's partial (key, value) pair. The block-wide sum
//                 of keys is used as the mean and the sum of values as E[x^2]
//                 (callers in this file accumulate x / ld and x * x / ld).
//   ld          - row length.
//   offset      - index of the first row element in `output`.
//   bias, scale - per-column beta / gamma, indexed by column.
//   output      - buffer holding the row; overwritten with the result.
//   eps         - added to the variance before the reciprocal square root.
//
// Performs a block reduction and __syncthreads(), so every thread of the block
// must call it.
template <typename T, int TPB>
__device__ inline void LayerNorm(const phi::funcs::kvp<T> &thread_data,
                                 const int ld,
                                 const int offset,
                                 const T *bias,
                                 const T *scale,
                                 T *output,
                                 T eps) {
  using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  // cub's BlockReduce result is only valid in thread 0, so thread 0 publishes
  // the statistics through shared memory.
  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    // Var[x] = E[x^2] - mean^2.
    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < ld; i += TPB) {
    const int idx = offset + i;
    const T val = output[idx];
    const T g(scale[i]);
    const T b(bias[i]);
    output[idx] = g * (val - mu) * rsigma + b;
  }
}

// Variant of LayerNorm for rows stored as two-component vectors (T2 has .x and
// .y members). Statistics are kept in the scalar type T and applied to both
// components of every element of output[offset, offset + ld), in place.
//
// Template params:
//   T   - scalar statistics type.
//   T2  - two-component storage type of bias / scale / output.
//   TPB - threads per block; sizes the cub::BlockReduce and is the loop
//         stride, so it must equal blockDim.x.
// Params:
//   thread_data - this thread's partial (key, value) pair. The block-wide sum
//                 of keys is used as the mean and the sum of values as E[x^2].
//   ld          - number of T2 elements in the row.
//   offset      - index of the first row element in `output`.
//   bias, scale - per-column beta / gamma, indexed by column.
//   output      - buffer holding the row; overwritten with the result.
//   eps         - added to the variance before the reciprocal square root.
//
// Performs a block reduction and __syncthreads(), so every thread of the block
// must call it.
template <typename T, typename T2, int TPB>
__device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
                                  const int ld,
                                  const int offset,
                                  const T2 *bias,
                                  const T2 *scale,
                                  T2 *output,
                                  T eps) {
  using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());

  // cub's BlockReduce result is only valid in thread 0, so thread 0 publishes
  // the statistics through shared memory.
  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    // Var[x] = E[x^2] - mean^2.
    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
  }
  __syncthreads();

  for (int i = threadIdx.x; i < ld; i += TPB) {
    const int idx = offset + i;
    T2 val = output[idx];
    const T2 g = scale[i];
    const T2 b = bias[i];
    val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x);
    val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y);
    output[idx] = val;
  }
}

// Sums `input_num` embedding lookups for one token and layer-normalizes the
// result into output[seq_pos * hidden, (seq_pos + 1) * hidden).
//
// Launch layout (as used by EmbEltwiseLayerNormFunctor):
//   grid  = (seq_len, batch); token index seq_pos = blockIdx.y +
//           blockIdx.x * gridDim.y.
//   block = TPB threads; blockDim.x must equal TPB because the column loop
//           strides by TPB and LayerNorm reduces over TPB threads.
//   dynamic shared memory = input_num * sizeof(int64_t) (one id per input).
// Params:
//   hidden    - row width of every embedding table and of the output.
//   ids       - input_num device pointers stored as int64_t; ids[i] points to
//               an int64 id array indexed by seq_pos.
//   embs      - input_num device pointers stored as int64_t; embs[i] points to
//               an embedding table of T with rows of `hidden` elements.
//   scale/bias- layer-norm gamma / beta of length hidden.
//   eps       - layer-norm epsilon.
template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden,
                                          const int64_t *ids,
                                          const T *scale,
                                          const T *bias,
                                          const int64_t *embs,
                                          T *output,
                                          T eps,
                                          int input_num) {
  cub::Sum pair_sum;
  // blockIdx.x: position in the sequence
  // blockIdx.y: batch
  // gridDim.x: Seq
  // gridDim.y: Batch

  extern __shared__ int64_t array_id[];

  const T rhidden = T(1.f) / T(hidden);
  const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
  // One thread gathers the token's id from every input so the whole block can
  // read them from shared memory.
  if (threadIdx.x == 0) {
    for (int i = 0; i < input_num; ++i) {
      const int64_t *ids_p = reinterpret_cast<const int64_t *>(ids[i]);
      array_id[i] = ids_p[seq_pos];
    }
  }
  __syncthreads();

  const int64_t out_offset = seq_pos * hidden;

  phi::funcs::kvp<T> thread_data(0, 0);

#pragma unroll
  for (int it = threadIdx.x; it < hidden; it += TPB) {
    T val = 0;
    for (int i = 0; i < input_num; ++i) {
      val += reinterpret_cast<const T *>(embs[i])[array_id[i] * hidden + it];
    }

    output[out_offset + it] = val;
    // Accumulate (x / hidden, x^2 / hidden) for the mean / E[x^2] reduction.
    const T rhiddenval = rhidden * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<T>(rhiddenval, rhiddenval * val));
  }
  LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: EmbEltwiseLayernormKernel
// half / 256-thread specialization of EmbEltwiseLayernormKernel. Same launch
// layout and parameters as the generic kernel; the body is compiled only for
// architectures with fp16 arithmetic (otherwise the kernel is empty).
template <>
__global__ void EmbEltwiseLayernormKernel<half, 256>(int hidden,
                                                     const int64_t *ids,
                                                     const half *scale,
                                                     const half *bias,
                                                     const int64_t *embs,
                                                     half *output,
                                                     half eps,
                                                     int input_num) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  cub::Sum pair_sum;
  // blockIdx.x: position in the sequence
  // blockIdx.y: batch
  // gridDim.x: Seq
  // gridDim.y: Batch

  extern __shared__ int64_t array_id[];

  const half rhidden = half(1.f) / half(hidden);
  const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
  // One thread gathers the token's id from every input so the whole block can
  // read them from shared memory.
  if (threadIdx.x == 0) {
    for (int i = 0; i < input_num; ++i) {
      const int64_t *ids_p = reinterpret_cast<const int64_t *>(ids[i]);
      array_id[i] = ids_p[seq_pos];
    }
  }
  __syncthreads();

  const int64_t out_offset = seq_pos * hidden;

  phi::funcs::kvp<half> thread_data(0, 0);

#pragma unroll
  for (int it = threadIdx.x; it < hidden; it += 256) {
    half val = 0;
    for (int i = 0; i < input_num; ++i) {
      val += reinterpret_cast<const half *>(embs[i])[array_id[i] * hidden + it];
    }

    output[out_offset + it] = val;
    // Accumulate (x / hidden, x^2 / hidden) for the mean / E[x^2] reduction.
    const half rhiddenval = rhidden * val;
    thread_data = pair_sum(thread_data,
                           phi::funcs::kvp<half>(rhiddenval, rhiddenval * val));
  }
  LayerNorm<half, 256>(
      thread_data, hidden, out_offset, bias, scale, output, eps);
#endif
}
#endif  // @} End Half kernel: EmbEltwiseLayernormKernel

// Launches EmbEltwiseLayernormKernel on `stream`: one 256-thread block per
// (sequence position, batch) pair, with input_num * sizeof(int64_t) bytes of
// dynamic shared memory for the per-token ids. `eps` is converted to T.
// Does nothing when the grid would be empty (batch or seq_len <= 0); such a
// launch is an invalid configuration.
template <typename T>
void EmbEltwiseLayerNormFunctor<T>::operator()(int batch,
                                               int seq_len,
                                               int hidden,
                                               const int64_t *ids,
                                               const T *scale,
                                               const T *bias,
                                               const int64_t *embs,
                                               T *output,
                                               float eps,
                                               int input_num,
                                               gpuStream_t stream) {
  if (batch <= 0 || seq_len <= 0) {
    return;
  }
  const unsigned tpb = 256;
  const dim3 grid(seq_len, batch, 1);
  const dim3 block(tpb, 1, 1);
  int shared_bytes = input_num * sizeof(int64_t);
  EmbEltwiseLayernormKernel<T, tpb><<<grid, block, shared_bytes, stream>>>(
      hidden, ids, scale, bias, embs, output, eps, input_num);
}

template class EmbEltwiseLayerNormFunctor<float>;

// The device function 'operator()' is not supported until CUDA 10.0.
// HIP defines __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so the half
// instantiation is only built for CUDA.
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class EmbEltwiseLayerNormFunctor<half>;
#endif

// Row-wise softmax of (qk_buf_ + bias_qk_), written back into qk_buf_.
//
// One block per row of `seq_len` elements (row = blockIdx.x). Each thread
// handles at most one element, so blockDim.x must be >= seq_len; it must also
// be a multiple of 32 (asserted). Math is done in float. batch_size and
// head_num are unused. `mask` is forwarded to the phi block reductions.
template <typename T>
__global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
                                        const T *bias_qk_,
                                        const int batch_size,
                                        const int head_num,
                                        const int seq_len,
                                        const unsigned mask) {
  int qk_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Idle threads contribute -1e20 to the max and 0 to the sum.
  float tmp = threadIdx.x < seq_len
                  ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
                                       bias_qk_[threadIdx.x + qk_offset])
                  : -1e20f;
  // Subtract the row max before exponentiating for numerical stability.
  float max_val = phi::funcs::blockReduceMax<float>(tmp, mask);

  float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
  float sum_val = phi::funcs::blockReduceSum<float>(qk_tmp, mask);

  if (threadIdx.x < seq_len) {
    qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
  }
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
// half specialization of SoftmaxKernelWithEltadd: same contract as the generic
// kernel, with exp/max/sum computed in float. The body is compiled only for
// architectures with fp16 arithmetic.
template <>
__global__ void SoftmaxKernelWithEltadd<half>(half *qk_buf_,
                                              const half *bias_qk_,
                                              const int batch_size,
                                              const int head_num,
                                              const int seq_len,
                                              const unsigned mask) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  int qk_offset = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Idle threads contribute -1e20 to the max and 0 to the sum.
  float tmp = threadIdx.x < seq_len
                  ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
                                       bias_qk_[threadIdx.x + qk_offset])
                  : -1e20f;
  float max_val = phi::funcs::blockReduceMax<float>(tmp, mask);

  float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
  float sum_val = phi::funcs::blockReduceSum<float>(qk_tmp, mask);

  if (threadIdx.x < seq_len) {
    qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
  }
#endif
}
#endif  // @} End Half kernel: SoftmaxKernelWithEltadd

// Paired variant of SoftmaxKernelWithEltadd: T is a two-component vector type
// (the launcher uses float2 / __half2) and `seq_len` counts pairs, i.e. half
// the row length. One block per row (row = blockIdx.x); each thread handles at
// most one pair, so blockDim.x must be >= seq_len and a multiple of 32
// (asserted). 1e-6 is added to the denominator. batch_size and head_num are
// unused.
template <typename T>
__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
                                         const T *bias_qk_,
                                         const int batch_size,
                                         const int head_num,
                                         const int seq_len,
                                         const unsigned mask) {
  int qk_offset = blockIdx.x * seq_len;
  int idx = threadIdx.x;
  assert(blockDim.x % 32 == 0);

  // Idle threads contribute -1e20 to the max and 0 to the sum.
  float2 tmp = idx < seq_len
                   ? phi::funcs::ToFloat2<T>(qk_buf_[idx + qk_offset] +
                                             bias_qk_[idx + qk_offset])
                   : make_float2(-1e20f, -1e20f);
  float max_val = phi::funcs::blockReduceMax<float>(max(tmp.x, tmp.y), mask);
  float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
                                              __expf(tmp.y - max_val))
                                : make_float2(0.f, 0.f);
  float sum_val =
      phi::funcs::blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;

  if (idx < seq_len) {
    qk_buf_[idx + qk_offset] =
        phi::funcs::FloatsToPair<T>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
  }
}

// half2 specialization of SoftmaxKernelWithEltadd2 (same contract). The body
// is compiled only for CUDA >= 10.0 on architectures with fp16 arithmetic.
template <>
__global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
                                                const half2 *bias_qk_,
                                                const int batch_size,
                                                const int head_num,
                                                const int seq_len,
                                                const unsigned mask) {
// operator "+" of half only supported after cuda version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
    (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
  int qk_offset = blockIdx.x * seq_len;
  int idx = threadIdx.x;
  assert(blockDim.x % 32 == 0);

  // Idle threads contribute -1e20 to the max and 0 to the sum.
  float2 tmp = idx < seq_len
                   ? phi::funcs::ToFloat2<half2>(qk_buf_[idx + qk_offset] +
                                                 bias_qk_[idx + qk_offset])
                   : make_float2(-1e20f, -1e20f);
  float max_val = phi::funcs::blockReduceMax<float>(max(tmp.x, tmp.y), mask);
  float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
                                              __expf(tmp.y - max_val))
                                : make_float2(0.f, 0.f);
  float sum_val =
      phi::funcs::blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;

  if (idx < seq_len) {
    qk_buf_[idx + qk_offset] =
        phi::funcs::FloatsToPair<half2>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
  }
#endif
}

// Row-wise softmax of (qk_buf + bias_qk), written back into qk_buf, for rows
// longer than one block can cover one-element-per-thread.
//
// One block per row of `seq_len` elements (row = blockIdx.x); blockDim.x must
// be a multiple of 32 (asserted). Each thread walks columns threadIdx.x,
// threadIdx.x + blockDim.x, ... while they are < seq_len, so any seq_len is
// handled. Threads with no columns still join the block reductions, with
// neutral values (-1e20 for max, 0 for sum). batch_size and head_num are
// unused. `mask` is forwarded to the phi block reductions.
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
                                                const T *bias_qk,
                                                const int batch_size,
                                                const int head_num,
                                                const int seq_len,
                                                const unsigned mask) {
  assert(blockDim.x % 32 == 0);
  // 64-bit row offset: rows * seq_len can exceed INT_MAX for long sequences.
  const int64_t qk_offset = static_cast<int64_t>(blockIdx.x) * seq_len;
  T *row = qk_buf + qk_offset;
  const T *bias_row = bias_qk + qk_offset;

  // Bounded loops: an unbounded `i < seq_len` test on the loop base would let
  // threadIdx.x + i run past the row end and corrupt the next row.
  T stride_max = -1e20f;
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    const T cur = row[i] + bias_row[i];
    stride_max = cur > stride_max ? cur : stride_max;
  }
  T max_val = phi::funcs::blockReduceMax<T>(stride_max, mask);

  T stride_sum = 0.f;
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    stride_sum += __expf(row[i] + bias_row[i] - max_val);
  }
  T sum_val = phi::funcs::blockReduceSum<T>(stride_sum, mask);

  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    row[i] = (T)(__expf(row[i] + bias_row[i] - max_val) / sum_val);
  }
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltaddForLarge
// half specialization of SoftmaxKernelWithEltaddForLarge: same contract as the
// generic kernel, with max/exp/sum computed in float. The body is compiled
// only for architectures with fp16 arithmetic.
template <>
__global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
                                                const half *bias_qk,
                                                const int batch_size,
                                                const int head_num,
                                                const int seq_len,
                                                const unsigned mask) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  assert(blockDim.x % 32 == 0);
  // 64-bit row offset: rows * seq_len can exceed INT_MAX for long sequences.
  const int64_t qk_offset = static_cast<int64_t>(blockIdx.x) * seq_len;
  half *row = qk_buf + qk_offset;
  const half *bias_row = bias_qk + qk_offset;

  // Loops stop at seq_len so no thread touches the neighbouring row.
  float stride_max = -1e20f;
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    const float cur = static_cast<float>(row[i] + bias_row[i]);
    stride_max = cur > stride_max ? cur : stride_max;
  }
  float max_val = phi::funcs::blockReduceMax<float>(stride_max, mask);

  float stride_sum = 0.f;
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    stride_sum += __expf(static_cast<float>(row[i] + bias_row[i]) - max_val);
  }
  float sum_val = phi::funcs::blockReduceSum<float>(stride_sum, mask);

  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    const float e =
        __expf(static_cast<float>(row[i] + bias_row[i]) - max_val);
    row[i] = (half)(e / sum_val);
  }
#endif
}
#endif  // @} End Half kernel: SoftmaxKernelWithEltaddForLarge

// Paired variant of SoftmaxKernelWithEltaddForLarge: T is a two-component
// vector type (the launcher uses float2 / __half2) and `seq_len` counts pairs,
// i.e. half the row length. One block per row (row = blockIdx.x); blockDim.x
// must be a multiple of 32 (asserted). Each thread walks pairs threadIdx.x,
// threadIdx.x + blockDim.x, ... while they are < seq_len. Threads with no pairs
// still join the block reductions, with neutral values. 1e-6 is added to the
// denominator. batch_size and head_num are unused.
template <typename T>
__global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
                                                 const T *bias_qk_,
                                                 const int batch_size,
                                                 const int head_num,
                                                 const int seq_len,
                                                 const unsigned mask) {
  assert(blockDim.x % 32 == 0);
  // 64-bit row offset: rows * seq_len can exceed INT_MAX for long sequences.
  const int64_t qk_offset = static_cast<int64_t>(blockIdx.x) * seq_len;
  T *row = qk_buf_ + qk_offset;
  const T *bias_row = bias_qk_ + qk_offset;

  // Loops stop at seq_len so no thread touches the neighbouring row.
  float2 stride_max = make_float2(-1e20f, -1e20f);
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    float2 cur = phi::funcs::ToFloat2<T>(row[i] + bias_row[i]);
    stride_max.x = max(stride_max.x, cur.x);
    stride_max.y = max(stride_max.y, cur.y);
  }
  float max_val =
      phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);

  float2 stride_sum = make_float2(0.f, 0.f);
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    float2 cur = phi::funcs::ToFloat2<T>(row[i] + bias_row[i]);
    stride_sum.x += __expf(cur.x - max_val);
    stride_sum.y += __expf(cur.y - max_val);
  }

  float sum_val =
      phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
      1e-6f;

  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    float2 cur = phi::funcs::ToFloat2<T>(row[i] + bias_row[i]);
    row[i] = phi::funcs::FloatsToPair<T>(__expf(cur.x - max_val) / sum_val,
                                         __expf(cur.y - max_val) / sum_val);
  }
}

// half2 specialization of SoftmaxKernelWithEltaddForLarge2 (same contract).
// The body is compiled only for CUDA >= 10.0 on architectures with fp16
// arithmetic.
template <>
__global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
                                                 const half2 *bias_qk_,
                                                 const int batch_size,
                                                 const int head_num,
                                                 const int seq_len,
                                                 const unsigned mask) {
// operator "+" of half only supported after cuda version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
    (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
  assert(blockDim.x % 32 == 0);
  // 64-bit row offset: rows * seq_len can exceed INT_MAX for long sequences.
  const int64_t qk_offset = static_cast<int64_t>(blockIdx.x) * seq_len;
  half2 *row = qk_buf_ + qk_offset;
  const half2 *bias_row = bias_qk_ + qk_offset;

  // Loops stop at seq_len so no thread touches the neighbouring row.
  float2 stride_max = make_float2(-1e20f, -1e20f);
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    float2 cur = phi::funcs::ToFloat2<half2>(row[i] + bias_row[i]);
    stride_max.x = max(stride_max.x, cur.x);
    stride_max.y = max(stride_max.y, cur.y);
  }
  float max_val =
      phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);

  float2 stride_sum = make_float2(0.f, 0.f);
  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    float2 cur = phi::funcs::ToFloat2<half2>(row[i] + bias_row[i]);
    stride_sum.x += __expf(cur.x - max_val);
    stride_sum.y += __expf(cur.y - max_val);
  }

  float sum_val =
      phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
      1e-6f;

  for (int i = threadIdx.x; i < seq_len; i += blockDim.x) {
    float2 cur = phi::funcs::ToFloat2<half2>(row[i] + bias_row[i]);
    row[i] = phi::funcs::FloatsToPair<half2>(
        __expf(cur.x - max_val) / sum_val, __expf(cur.y - max_val) / sum_val);
  }
#endif
}

// Computes attention scores with a batched GEMM over batch_size * head_num
// matrices: qk_buf_ = alpha * op(q) * op(k) + beta * qk_buf_, where op() is a
// transpose when q_trans / k_trans is set. Each product is seq_len x seq_len
// with inner dimension size_per_head, and the q / k batch stride is
// seq_len * size_per_head. It then applies a row-wise softmax of
// (qk_buf_ + bias_qk) in place, on the context's stream.
//
// Softmax dispatch (one block per score row):
//   seq_len <= 1024: single-pass kernels where each thread owns one element,
//                    or one pair of elements when seq_len is even.
//   seq_len >  1024: 512-thread blocks that loop over the row.
// Even seq_len uses the paired kernels (float2 for float, __half2 otherwise).
template <typename T>
inline void MatMulWithHeadQK(const phi::GPUContext &context,
                             int head_num,
                             int seq_len,
                             int size_per_head,
                             int batch_size,
                             bool q_trans,
                             bool k_trans,
                             T *q_buf_,
                             T *k_buf_,
                             T *qk_buf_,
                             const T *bias_qk,
                             T alpha,
                             T beta) {
  CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;

  typedef typename CUDATypeTraits<T>::TYPE run_type;
  auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
  auto stream = context.stream();

  blas.BatchedGEMM(transA,
                   transB,
                   seq_len,
                   seq_len,
                   size_per_head,
                   static_cast<run_type>(alpha),
                   reinterpret_cast<run_type *>(q_buf_),
                   reinterpret_cast<run_type *>(k_buf_),
                   static_cast<run_type>(beta),
                   reinterpret_cast<run_type *>(qk_buf_),
                   batch_size * head_num,
                   seq_len * size_per_head,
                   seq_len * size_per_head);

  // One block per row of the batch_size * head_num score matrices.
  const int grid = batch_size * head_num * seq_len;

  if (seq_len <= 1024) {
    // Align block to 32 threads while covering the whole row.
    if (seq_len % 2 == 0) {
      // seq_len / 2 pairs per row, rounded up to a multiple of 32 threads.
      const int block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
      if (std::is_same<T, float>::value) {
        SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
            reinterpret_cast<float2 *>(qk_buf_),
            reinterpret_cast<const float2 *>(bias_qk),
            batch_size,
            head_num,
            seq_len / 2,
            FINAL_MASK);
      } else {
        SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
            reinterpret_cast<__half2 *>(qk_buf_),
            reinterpret_cast<const __half2 *>(bias_qk),
            batch_size,
            head_num,
            seq_len / 2,
            FINAL_MASK);
      }
    } else {
      const int block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
      SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
          qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
    }
  } else {
    const int block = 512;
    if (seq_len % 2 == 0) {
      if (std::is_same<T, float>::value) {
        SoftmaxKernelWithEltaddForLarge2<float2><<<grid, block, 0, stream>>>(
            reinterpret_cast<float2 *>(qk_buf_),
            reinterpret_cast<const float2 *>(bias_qk),
            batch_size,
            head_num,
            seq_len / 2,
            FINAL_MASK);
      } else {
        SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
            reinterpret_cast<__half2 *>(qk_buf_),
            reinterpret_cast<const __half2 *>(bias_qk),
            batch_size,
            head_num,
            seq_len / 2,
            FINAL_MASK);
      }
    } else {
      SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>(
          qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
    }
  }
}

// Batched GEMM over batch_size * head_num matrices:
// dst = alpha * op(qk_buf_) * op(v_buf_) + beta * dst, where op() is a
// transpose when qk_trans / v_trans is set. Each product is
// seq_len x size_per_head with inner dimension seq_len. The batch strides are
// seq_len * seq_len for qk_buf_ and seq_len * size_per_head for v_buf_.
template <typename T>
inline void MatMulWithHeadQKV(const phi::GPUContext &context,
                              int head_num,
                              int seq_len,
                              int size_per_head,
                              int batch_size,
                              bool qk_trans,
                              bool v_trans,
                              T *v_buf_,
                              const T *qk_buf_,
                              T *dst,
                              T alpha,
                              T beta) {
  typedef typename CUDATypeTraits<T>::TYPE run_type;
  auto blas = phi::funcs::GetBlas<phi::GPUContext, run_type>(context);
  CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
  CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;

  blas.BatchedGEMM(transA,
                   transB,
                   seq_len,
                   size_per_head,
                   seq_len,
                   static_cast<run_type>(alpha),
                   reinterpret_cast<const run_type *>(qk_buf_),
                   reinterpret_cast<run_type *>(v_buf_),
                   static_cast<run_type>(beta),
                   reinterpret_cast<run_type *>(dst),
                   batch_size * head_num,
                   seq_len * seq_len,
                   seq_len * size_per_head);
}

// Multi-head attention core on dev_ctx's stream.
//   tptr        - holds Q, K and V back to back, each of
//                 batch * head_num * seq_len * head_size elements; the
//                 attention output overwrites the Q region.
//   qkptr       - scratch for the scores; receives softmax(alpha * Q K^T
//                 + beta * qkptr + bias_qk_ptr) row-wise.
//   bias_qk_ptr - added to the scores before the softmax.
// The output is then computed as (scores) * V with scale 1 and the given beta.
template <typename T>
void MultiHeadGPUComputeFunctor<T>::operator()(const phi::GPUContext &dev_ctx,
                                               int batch,
                                               int seq_len,
                                               int head_num,
                                               int head_size,
                                               T *qkptr,
                                               const T *bias_qk_ptr,
                                               T *tptr,
                                               T alpha,
                                               T beta) {
  const int tsize = batch * head_num * seq_len * head_size;

  T *qptr = tptr;
  T *kptr = qptr + tsize;
  T *vptr = kptr + tsize;
  // batch gemm stride, softmaxwithscale.
  MatMulWithHeadQK<T>(dev_ctx,
                      head_num,
                      seq_len,
                      head_size,
                      batch,
                      false,
                      true,
                      qptr,
                      kptr,
                      qkptr,
                      bias_qk_ptr,
                      alpha,
                      beta);
  // batch gemm stride, transpose.
  MatMulWithHeadQKV<T>(dev_ctx,
                       head_num,
                       seq_len,
                       head_size,
                       batch,
                       false,
                       false,
                       vptr,
                       qkptr,
                       tptr,
                       T(1.0),
                       beta);
}

template class MultiHeadGPUComputeFunctor<float>;

// The device function 'operator()' is not supported until CUDA 10.0.
// HIP defines __HIP_NO_HALF_CONVERSIONS__ in hip.cmake, so the half
// instantiation is only built for CUDA.
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class MultiHeadGPUComputeFunctor<half>;
#endif

// Layer normalization of (input1 + input2) for rows of `hidden` elements that
// fit in one block: one block per row (row = blockIdx.x), one element per
// thread, so blockDim.x must equal TPB and be >= hidden. Threads with
// threadIdx.x >= hidden contribute zero to the statistics and write nothing.
// `num` is unused.
template <typename T, unsigned TPB>
__global__ void SkipLayerNormSmallKernel(int num,
                                         int hidden,
                                         const T *input1,
                                         const T *input2,
                                         T *output,
                                         const T *scale,
                                         const T *bias,
                                         T eps) {
  const T rld = T(1) / T(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<T> thread_data(0, 0);
  const int idx = offset + threadIdx.x;
  T val = 0;
  if (threadIdx.x < hidden) {
    val = input1[idx] + input2[idx];
    // (x / hidden, x^2 / hidden) for the mean / E[x^2] reduction.
    const T rldval = rld * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<T>(rldval, rldval * val));
  }
  LayerNormSmall<T, TPB>(
      val, thread_data, hidden, idx, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormSmallKernel
// half / 32-thread specialization of SkipLayerNormSmallKernel (same contract).
// The body is compiled only for architectures with fp16 arithmetic.
template <>
__global__ void SkipLayerNormSmallKernel<half, 32>(int num,
                                                   int hidden,
                                                   const half *input1,
                                                   const half *input2,
                                                   half *output,
                                                   const half *scale,
                                                   const half *bias,
                                                   half eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const half rld = half(1) / half(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
  phi::funcs::kvp<half> thread_data(0, 0);
  const int idx = offset + threadIdx.x;
  half val = 0;
  if (threadIdx.x < hidden) {
    val = input1[idx] + input2[idx];
    // (x / hidden, x^2 / hidden) for the mean / E[x^2] reduction.
    const half rldval = rld * val;
    thread_data =
        pair_sum(thread_data, phi::funcs::kvp<half>(rldval, rldval * val));
  }
  LayerNormSmall<half, 32>(
      val, thread_data, hidden, idx, bias, scale, output, eps);
#endif
}

template <>
773 774 775 776 777
__global__ void SkipLayerNormSmallKernel<half, 128>(int num,
                                                    int hidden,
                                                    const half *input1,
                                                    const half *input2,
                                                    half *output,
778 779 780
                                                    const half *scale,
                                                    const half *bias,
                                                    half eps) {
781 782 783 784
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const half rld = half(1) / half(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
785
  phi::funcs::kvp<half> thread_data(0, 0);
786 787 788 789 790
  const int idx = offset + threadIdx.x;
  half val = 0;
  if (threadIdx.x < hidden) {
    val = input1[idx] + input2[idx];
    const half rldval = rld * val;
791
    thread_data =
792
        pair_sum(thread_data, phi::funcs::kvp<half>(rldval, rldval * val));
793
  }
794 795
  LayerNormSmall<half, 128>(
      val, thread_data, hidden, idx, bias, scale, output, eps);
796 797 798 799
#endif
}

template <>
800 801 802 803 804
__global__ void SkipLayerNormSmallKernel<half, 384>(int num,
                                                    int hidden,
                                                    const half *input1,
                                                    const half *input2,
                                                    half *output,
805 806 807
                                                    const half *scale,
                                                    const half *bias,
                                                    half eps) {
808 809 810 811
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  const half rld = half(1) / half(hidden);
  const int offset = blockIdx.x * hidden;
  cub::Sum pair_sum;
812
  phi::funcs::kvp<half> thread_data(0, 0);
813 814 815 816 817
  const int idx = offset + threadIdx.x;
  half val = 0;
  if (threadIdx.x < hidden) {
    val = input1[idx] + input2[idx];
    const half rldval = rld * val;
818
    thread_data =
819
        pair_sum(thread_data, phi::funcs::kvp<half>(rldval, rldval * val));
820
  }
821 822
  LayerNormSmall<half, 384>(
      val, thread_data, hidden, idx, bias, scale, output, eps);
823 824
#endif
}
825
#endif  // @} End Half kernel: SkipLayerNormSmallKernel
826

// Skip connection + LayerNorm for rows of arbitrary width.
//
// Block b owns elements [b * hidden, (b + 1) * hidden). Each thread visits
// columns threadIdx.x, threadIdx.x + TPB, threadIdx.x + 2 * TPB, and so on.
// For every visited element, the sum x = input1 + input2 is written to
// `output`, and (x / hidden, x * x / hidden) is added to the thread's partial
// pair. The partials and the row offset are then handed to LayerNorm.
// `num` is not read.
// NOTE(review): LayerNorm is defined outside this view. Presumably it reduces
// the partials across the block and normalizes `output` in place; confirm.
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(int num,
                                    int hidden,
                                    const T *input1,
                                    const T *input2,
                                    T *output,
                                    const T *scale,
                                    const T *bias,
                                    T eps) {
  const int row_start = blockIdx.x * hidden;
  const T inv_hidden = T(1) / T(hidden);
  cub::Sum add;
  phi::funcs::kvp<T> partial(0, 0);

  int col = threadIdx.x;
  while (col < hidden) {
    const int elem = row_start + col;
    const T x = input1[elem] + input2[elem];
    const T scaled = inv_hidden * x;
    partial = add(partial, phi::funcs::kvp<T>(scaled, scaled * x));
    output[elem] = x;
    col += TPB;
  }
  LayerNorm<T, TPB>(partial, hidden, row_start, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel
// half specialization of SkipLayerNormKernel for 256-thread blocks. It uses
// the same algorithm as the generic template: strided row traversal, sums
// staged in `output`, and partials (x / hidden, x * x / hidden) passed to
// LayerNorm. The body is compiled only when
// CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) holds and is empty otherwise.
template <>
__global__ void SkipLayerNormKernel<half, 256>(int num,
                                               int hidden,
                                               const half *input1,
                                               const half *input2,
                                               half *output,
                                               const half *scale,
                                               const half *bias,
                                               half eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  constexpr int kThreads = 256;
  const int row_start = blockIdx.x * hidden;
  const half inv_hidden = half(1) / half(hidden);
  cub::Sum add;
  phi::funcs::kvp<half> partial(0, 0);

  int col = threadIdx.x;
  while (col < hidden) {
    const int elem = row_start + col;
    const half x = input1[elem] + input2[elem];
    const half scaled = inv_hidden * x;
    partial = add(partial, phi::funcs::kvp<half>(scaled, scaled * x));
    output[elem] = x;
    col += kThreads;
  }
  LayerNorm<half, kThreads>(
      partial, hidden, row_start, bias, scale, output, eps);
#endif
}
#endif  // @} End Half kernel: SkipLayerNormKernel

// Pair-vectorized variant of SkipLayerNormKernel.
//
// Each element of the T2 buffers carries two scalars (.x and .y). Here
// `hidden` counts T2 pairs, i.e. half the scalar row width, so the scale
// factor is 0.5f / hidden = 1 / (scalar width). Block b owns pairs
// [b * hidden, (b + 1) * hidden), and each thread strides through them by TPB.
// Pair sums are staged in `output`, and the per-thread partials (accumulated
// in T) are handed to LayerNorm2 together with the row offset. `num` is not
// read.
template <typename T, typename T2, unsigned TPB>
__global__ void SkipLayerNormKernel2(int num,
                                     int hidden,
                                     const T2 *input1,
                                     const T2 *input2,
                                     T2 *output,
                                     const T2 *scale,
                                     const T2 *bias,
                                     float eps) {
  const int row_start = blockIdx.x * hidden;
  const T inv_width = T(0.5f / hidden);  // hidden counts pairs, not scalars
  cub::Sum add;
  phi::funcs::kvp<T> partial(0, 0);

  int col = threadIdx.x;
  while (col < hidden) {
    const int elem = row_start + col;
    const T2 pair = input1[elem] + input2[elem];
    partial = add(partial,
                  phi::funcs::kvp<T>(inv_width * (pair.x + pair.y),
                                     inv_width * pair.x * pair.x +
                                         inv_width * pair.y * pair.y));
    output[elem] = pair;
    col += TPB;
  }
  LayerNorm2<T, T2, TPB>(partial, hidden, row_start, bias, scale, output, eps);
}

// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__  // @{ Half kernel: SkipLayerNormKernel2
// half2 specialization of SkipLayerNormKernel2 for 256-thread blocks. It uses
// the same algorithm as the generic template, with partials accumulated in
// half.
template <>
__global__ void SkipLayerNormKernel2<half, half2, 256>(int num,
                                                       int hidden,
                                                       const half2 *input1,
                                                       const half2 *input2,
                                                       half2 *output,
                                                       const half2 *scale,
                                                       const half2 *bias,
                                                       float eps) {
// operator "+" on half operands is only supported from CUDA 10.0 onward. The
// body is empty when that, or FP16 arch support, is missing.
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
  constexpr int kThreads = 256;
  const int row_start = blockIdx.x * hidden;
  const half inv_width = half(0.5f / hidden);  // hidden counts half2 pairs
  cub::Sum add;
  phi::funcs::kvp<half> partial(0, 0);

  int col = threadIdx.x;
  while (col < hidden) {
    const int elem = row_start + col;
    const half2 pair = input1[elem] + input2[elem];
    partial = add(partial,
                  phi::funcs::kvp<half>(inv_width * (pair.x + pair.y),
                                        inv_width * pair.x * pair.x +
                                            inv_width * pair.y * pair.y));
    output[elem] = pair;
    col += kThreads;
  }
  LayerNorm2<half, half2, kThreads>(
      partial, hidden, row_start, bias, scale, output, eps);
#endif
}
#endif  // @} End Half kernel: SkipLayerNormKernel2

// Host-side launcher for "skip connection + LayerNorm".
//
// input1, input2 and output are treated as num / hidden rows of `hidden`
// contiguous elements. One block per row is launched on `stream`:
//   hidden <= 32   -> SkipLayerNormSmallKernel, 32 threads
//   hidden <= 128  -> SkipLayerNormSmallKernel, 128 threads
//   hidden == 384  -> SkipLayerNormSmallKernel, 384 threads
//   otherwise      -> 256 threads:
//                     - SkipLayerNormKernel2 on float2 / __half2 pairs when
//                       hidden is even and T is float (or __half outside HIP);
//                     - the scalar SkipLayerNormKernel in every other case.
// Trailing num % hidden elements, if any, are not processed. Launch errors
// are not checked here.
//
// NOTE(review): the pair path reinterpret_casts the buffers to float2 /
// __half2, so it assumes the base pointers are 8-byte / 4-byte aligned.
// Confirm with callers.
template <typename T>
void SkipLayerNormFunctor<T>::operator()(const int num,
                                         const int hidden,
                                         const T *input1,
                                         const T *input2,
                                         const T *scale,
                                         const T *bias,
                                         T *output,
                                         float eps,
                                         gpuStream_t stream) {
  // A non-positive hidden size would make num / hidden a division by zero.
  if (hidden <= 0) {
    return;
  }
  // num < hidden (including empty input) yields zero rows. A zero-sized grid
  // is an invalid launch configuration, and there is nothing to compute.
  const int block = num / hidden;
  if (block <= 0) {
    return;
  }

  if (hidden <= 32) {
    constexpr int threads = 32;
    SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
        num, hidden, input1, input2, output, scale, bias, eps);
  } else if (hidden <= 128) {
    constexpr int threads = 128;
    SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
        num, hidden, input1, input2, output, scale, bias, eps);
  } else if (hidden == 384) {
    constexpr int threads = 384;
    SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
        num, hidden, input1, input2, output, scale, bias, eps);
  } else {
    constexpr int threads = 256;
    // The pair kernels consume two scalars per element, so they need an even
    // row width.
    const bool use_pairs = hidden % 2 == 0;
    if (use_pairs && std::is_same<T, float>::value) {
      SkipLayerNormKernel2<float, float2, threads>
          <<<block, threads, 0, stream>>>(
              num,
              hidden / 2,
              reinterpret_cast<const float2 *>(input1),
              reinterpret_cast<const float2 *>(input2),
              reinterpret_cast<float2 *>(output),
              reinterpret_cast<const float2 *>(scale),
              reinterpret_cast<const float2 *>(bias),
              eps);
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#ifndef __HIPCC__
    } else if (use_pairs && std::is_same<T, __half>::value) {
      SkipLayerNormKernel2<__half, __half2, threads>
          <<<block, threads, 0, stream>>>(
              num,
              hidden / 2,
              reinterpret_cast<const __half2 *>(input1),
              reinterpret_cast<const __half2 *>(input2),
              reinterpret_cast<__half2 *>(output),
              reinterpret_cast<const __half2 *>(scale),
              reinterpret_cast<const __half2 *>(bias),
              eps);
#endif
    } else {
      // Odd widths, and element types without a pair kernel, use the scalar
      // kernel. It handles any row width, so output is always written rather
      // than relying on an assert that vanishes under NDEBUG.
      SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
          num, hidden, input1, input2, output, scale, bias, eps);
    }
  }
}

// Explicit instantiations of SkipLayerNormFunctor.
template class SkipLayerNormFunctor<float>;

// The half instantiation requires CUDA >= 10.0, because device-side
// operator() is not supported before that release. It is also skipped on
// HIP, where hip.cmake defines __HIP_NO_HALF_CONVERSIONS__.
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
template class SkipLayerNormFunctor<half>;
#endif

}  // namespace math
}  // namespace operators
}  // namespace paddle