cuda_device_function.h 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16

17 18
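// NOTE: PADDLE_CUDA_FP16 is defined before float16.h is included; it
// presumably enables the float16 <-> half conversions used by the shuffle
// helpers in this header.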
#define PADDLE_CUDA_FP16
19 20
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
21
#include "paddle/fluid/platform/float16.h"
22 23 24 25

namespace paddle {
namespace platform {

26 27 28
// CREATE_SHFL_MASK(mask, predicate) assigns to `mask` a lane mask for use as
// the membermask argument of the warp shuffle helpers below. Every variant is
// an expression without a trailing semicolon; the caller supplies the `;`.
//  - HIP:        __ballot(predicate).
//  - CUDA < 9.0: the legacy intrinsics take no mask, so `mask` is set to 0u and
//                `predicate` is not evaluated.
//  - CUDA >= 9.0: __ballot_sync over the full warp (FULL_WARP_MASK), which
//                requires every lane of the warp to execute the macro.
#ifdef PADDLE_WITH_HIP
#define CREATE_SHFL_MASK(mask, predicate) mask = __ballot((predicate))
#else
#if CUDA_VERSION < 9000
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
#endif
C
chengduoZH 已提交
37

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
// Rounds `dim` up to the next power of two, clamped to the range [32, 1024].
// Values <= 32 (including zero and negatives) map to 32; values > 512 map to
// 1024. The results are exactly the case values handled by
// CUDA_LAUNCH_KERNEL_HELPER.
inline static int RoundToPowerOfTwo(int dim) {
  constexpr int kMinDim = 32;
  constexpr int kMaxDim = 1024;
  int rounded = kMinDim;
  // Double until we cover `dim`, but never exceed the upper clamp.
  while (rounded < dim && rounded < kMaxDim) {
    rounded *= 2;
  }
  return rounded;
}

// Emits a single `case (dim):` of a switch statement. Within that case,
// `kPowerOfTwoDim` is a constexpr copy of `dim` that the code passed in
// __VA_ARGS__ can use (e.g. as a template argument). Control leaves the switch
// once that code has run.
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...)    \
  case (dim): {                              \
    constexpr auto kPowerOfTwoDim = (dim);   \
    __VA_ARGS__;                             \
    break;                                   \
  }

60 61 62 63 64 65
// Expands to the `case` labels 1024, 512, 256, 128, 64 and 32 for use inside a
// switch statement. These match the possible results of RoundToPowerOfTwo.
// In each case the code in __VA_ARGS__ runs with `kPowerOfTwoDim` bound to
// that case's value as a compile-time constant.
#define CUDA_LAUNCH_KERNEL_HELPER(...)          \
  CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__); \
  CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__);   \
  CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);

C
chengduoZH 已提交
68
// Shuffles `val` down the warp by `delta` lanes. Each lane receives the value
// held by lane (lane_id + delta) within its group of `width` lanes. Lanes whose
// source would fall past the end of the group keep their own value (semantics
// of __shfl_down / __shfl_down_sync).
// `mask` is the membermask for __shfl_down_sync on CUDA >= 9.0. It is ignored on
// HIP and on CUDA < 9.0, which use the legacy mask-less __shfl_down.
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                 int delta,
                                                 int width = warpSize) {
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION < 9000
  return __shfl_down(val, static_cast<unsigned>(delta), width);
#else
  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
#endif
}

79 80 81
// Butterfly (XOR) shuffle: each lane receives `val` from lane
// (lane_id ^ lane_mask) of the warp.
// NOTE: the third argument is the XOR lane mask, not a sub-group width. It is
// forwarded as the `laneMask` argument of __shfl_xor / __shfl_xor_sync, and the
// shuffle always spans the full warp. Callers typically pass reduction offsets
// (16, 8, 4, ...). The default of warpSize is kept only for compatibility.
// `mask` is the membermask for __shfl_xor_sync on CUDA >= 9.0. It is ignored on
// HIP and CUDA < 9.0.
template <typename T>
__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
                                                int lane_mask = warpSize) {
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION < 9000
  return __shfl_xor(val, lane_mask);
#else
  return __shfl_xor_sync(mask, val, lane_mask);
#endif
}

89
// float16 specializations of the shuffle helpers.
// CUDA 9.0+ has native __shfl_*_sync overloads for `half`, so float16 is
// shuffled as `half` there. On HIP the value goes through `float`. On
// CUDA < 9.0 it goes through `half` with the legacy mask-less intrinsics, and
// `mask` is ignored.
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION < 9000
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
#ifdef PADDLE_WITH_HIP
  return float16(__shfl_down(static_cast<float>(val),
                             static_cast<unsigned>(delta), width));
#else
  return float16(
      __shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
#endif
}

template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
                                                      float16 val, int width) {
#ifdef PADDLE_WITH_HIP
  return float16(__shfl_xor(static_cast<float>(val), width));
#else
  return float16(__shfl_xor(static_cast<half>(val), width));
#endif
}
#else
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
  return float16(__shfl_down_sync(mask, static_cast<half>(val),
                                  static_cast<unsigned>(delta), width));
}

template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
                                                      float16 val, int width) {
  return float16(__shfl_xor_sync(mask, static_cast<half>(val), width));
}
#endif

// complex64 / complex128 specializations. The real and imaginary parts are
// shuffled independently through the float/double primary templates. Those
// templates already pick the right intrinsic for HIP, CUDA < 9.0 and
// CUDA >= 9.0, so these specializations are valid on every backend. (They were
// previously defined only for CUDA >= 9.0, so complex shuffles did not compile
// on HIP.)
template <>
__forceinline__ __device__ paddle::platform::complex64 CudaShuffleDownSync(
    unsigned mask, paddle::platform::complex64 val, int delta, int width) {
  float real =
      CudaShuffleDownSync(mask, static_cast<float>(val.real), delta, width);
  float imag =
      CudaShuffleDownSync(mask, static_cast<float>(val.imag), delta, width);
  return paddle::platform::complex64(real, imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex128 CudaShuffleDownSync(
    unsigned mask, paddle::platform::complex128 val, int delta, int width) {
  double real =
      CudaShuffleDownSync(mask, static_cast<double>(val.real), delta, width);
  double imag =
      CudaShuffleDownSync(mask, static_cast<double>(val.imag), delta, width);
  return paddle::platform::complex128(real, imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex64 CudaShuffleXorSync(
    unsigned mask, paddle::platform::complex64 val, int width) {
  float real = CudaShuffleXorSync(mask, static_cast<float>(val.real), width);
  float imag = CudaShuffleXorSync(mask, static_cast<float>(val.imag), width);
  return paddle::platform::complex64(real, imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex128 CudaShuffleXorSync(
    unsigned mask, paddle::platform::complex128 val, int width) {
  double real = CudaShuffleXorSync(mask, static_cast<double>(val.real), width);
  double imag = CudaShuffleXorSync(mask, static_cast<double>(val.imag), width);
  return paddle::platform::complex128(real, imag);
}

C
chengduoZH 已提交
170
// Indexed shuffle: each lane receives `val` from lane `src_line` of its group
// of `width` lanes. Per __shfl / __shfl_sync semantics, a source index outside
// [0, width) is taken modulo `width`.
// `mask` is the membermask for __shfl_sync on CUDA >= 9.0. HIP and CUDA < 9.0
// use the legacy mask-less __shfl and ignore it.
// NOTE(review): unlike CudaShuffleDownSync, `width` defaults to 32 rather than
// warpSize. On HIP (64-wide warps) the default therefore shuffles within
// 32-lane halves — confirm this is intended.
template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                             int width = 32) {
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION < 9000
  return __shfl(val, src_line, width);
#else
  return __shfl_sync(mask, val, src_line, width);
#endif
}
179 180

// Returns positive infinity (the float-typed INFINITY macro) converted to T.
// NOTE(review): only meaningful for floating-point-like T. Converting infinity
// to an integral type is undefined behavior.
template <typename T>
HOSTDEVICE T Infinity() {
  return INFINITY;
}

// Block-wide sum of `val` in two stages:
//   1. each warp reduces its values with shuffle-down;
//   2. the per-warp partial sums are staged in shared memory, and the threads
//      with tid < kWarpSize reduce them the same way.
// Every thread of the block must call this function, because it contains
// __syncthreads(). `tid` is used as the thread's lane/warp position, so it
// presumably equals threadIdx.x — TODO confirm against callers. Lanes with
// tid >= len are left out of the first-stage shuffle mask.
// On return only the thread with tid == 0 is guaranteed to hold the full sum;
// other threads hold partial sums.
// Assumes at most kWarpSize * kWarpSize threads participate, since there is one
// shared slot per warp.
//
// NOTE(zcd): The warp size should be taken from the GPU's parameters rather
// than hard-coded. To make the reduction efficient it uses warp-level
// parallelism and assumes a warp size of 32 (64 on HIP). Other GPUs may
// differ, but most cards use 32.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
  // Named kWarpSize so it does not shadow the CUDA built-in `warpSize`. The
  // shuffle helpers' default width still uses the built-in, as before.
#ifdef PADDLE_WITH_HIP
  constexpr int kWarpSize = 64;
#else
  constexpr int kWarpSize = 32;
#endif
  // One slot per warp for the stage-1 partial sums.
  __shared__ T shm[kWarpSize];
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);

  // Stage 1: lane 0 of each warp ends up with that warp's sum.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    val += platform::CudaShuffleDownSync(mask, val, offset);

  // Zero every slot so that slots of warps that don't exist add nothing below.
  if (tid < kWarpSize) shm[tid] = 0;
  __syncthreads();

  if (tid % kWarpSize == 0) {
    shm[tid / kWarpSize] = val;
  }
  __syncthreads();

  CREATE_SHFL_MASK(mask, tid < kWarpSize);

  // Stage 2: the first kWarpSize threads reduce the per-warp partial sums.
  if (tid < kWarpSize) {
    val = shm[tid];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      val += platform::CudaShuffleDownSync(mask, val, offset);
  }
  return val;
}

}  // namespace platform
}  // namespace paddle