cuda_device_function.h 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16

17 18
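// NOTE: Define PADDLE_CUDA_FP16 before including float16.h so that float16 <->
// CUDA `half` conversion support is available in this header (it is needed by
// the static_cast<half> shuffles below).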
#define PADDLE_CUDA_FP16
19
#include "paddle/fluid/platform/complex.h"
20
#include "paddle/fluid/platform/float16.h"
21 22 23 24

namespace paddle {
namespace platform {

// CREATE_SHFL_MASK(mask, predicate) assigns to `mask` a per-lane ballot of
// `predicate` across the calling warp (bit i is set iff lane i evaluated
// `predicate` as true). The result is meant to be passed as the `mask`
// argument of the CudaShuffle*Sync wrappers in this file.
#if defined(PADDLE_WITH_HIP)
// HIP: __ballot takes no member mask.
// NOTE(review): on 64-wide wavefronts the ballot presumably has 64 bits, so
// storing it in a 32-bit `unsigned` would drop the upper lanes. The HIP
// shuffle wrappers in this file ignore `mask`, so this looks harmless today —
// confirm before relying on `mask` on HIP.
#define CREATE_SHFL_MASK(mask, predicate) mask = __ballot((predicate))
#else
// CUDA: the ballot names all 32 lanes as participants, so per the CUDA
// warp-vote rules every non-exited lane of the warp must reach
// CREATE_SHFL_MASK; do not invoke it inside a lane-divergent branch.
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Returns the smallest power of two that is >= `dim`, clamped to [32, 1024].
// Inputs <= 32 (including zero and negatives) yield 32, and inputs > 512 yield
// 1024. The possible results are exactly the case labels expanded by
// CUDA_LAUNCH_KERNEL_HELPER.
inline static int RoundToPowerOfTwo(int dim) {
  constexpr int kMinDim = 32;
  constexpr int kMaxDim = 1024;
  int rounded = kMinDim;
  // Double until we reach `dim`, but never beyond the upper clamp.
  while (rounded < dim && rounded < kMaxDim) {
    rounded *= 2;
  }
  return rounded;
}

// Expands to a single `case (dim):` of an enclosing switch. Inside the case,
// `kPowerOfTwoDim` is a compile-time constant equal to `dim`, so the
// statement(s) given as the variadic argument may use it (e.g. as a template
// argument). After those statements run, control leaves the switch.
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...)  \
  case (dim): {                            \
    constexpr auto kPowerOfTwoDim = (dim); \
    __VA_ARGS__;                           \
    break;                                 \
  }

// Expands to one CUDA_LAUNCH_KERNEL_BASE case for every value RoundToPowerOfTwo
// can return (32, 64, 128, 256, 512, 1024). Place it inside a `switch` on such
// a value; within the matching case `kPowerOfTwoDim` is that value as a
// compile-time constant.
#define CUDA_LAUNCH_KERNEL_HELPER(...)          \
  CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);   \
  CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__);   \
  CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__);

// Portable warp shuffle-down. Each lane receives the `val` of the lane `delta`
// positions higher within its `width`-lane segment; see the CUDA/HIP shuffle
// documentation for lanes whose source falls outside the segment.
//   CUDA: forwards to __shfl_down_sync; only lanes set in `mask` participate.
//   HIP:  forwards to __shfl_down, which takes no member mask, so `mask` is
//         ignored.
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                 int delta,
                                                 int width = warpSize) {
#if defined(PADDLE_WITH_HIP)
  return __shfl_down(val, delta, width);
#else
  const unsigned lane_delta = static_cast<unsigned>(delta);
  return __shfl_down_sync(mask, val, lane_delta, width);
#endif
}

// Portable warp butterfly (XOR) shuffle. Each lane receives the `val` held by
// the lane whose index is its own lane index XOR `lane_mask`.
//   CUDA: forwards to __shfl_xor_sync; only lanes set in `mask` participate.
//   HIP:  forwards to __shfl_xor, which takes no member mask, so `mask` is
//         ignored.
//
// The third argument is the XOR lane mask, NOT a segment width. It is passed
// as the `laneMask` operand of __shfl_xor(_sync), and the segment width is
// left at the intrinsic's default. (It was historically named `width`.) Only
// the name changed; position, type and default are the same, so call sites
// and the float16/complex explicit specializations are unaffected.
//
// NOTE(review): the default `warpSize` is not a meaningful lane mask. Every
// lane index XOR warpSize falls outside the warp, so the call presumably just
// returns `val` unchanged. Always pass an explicit mask such as 16, 8, 4, 2
// or 1.
template <typename T>
__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
                                                int lane_mask = warpSize) {
#if defined(PADDLE_WITH_HIP)
  return __shfl_xor(val, lane_mask);
#else
  return __shfl_xor_sync(mask, val, lane_mask);
#endif
}
#if defined(PADDLE_WITH_HIP)
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
  // HIP path: widen to float, shuffle, then narrow back to float16. `mask` is
  // unused because HIP's __shfl_down takes no member mask.
  const float widened = static_cast<float>(val);
  const float shifted =
      __shfl_down(widened, static_cast<unsigned>(delta), width);
  return float16(shifted);
}

template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
    unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
  // HIP path: shuffle the real part, then the imaginary part, as two
  // independent float shuffles. `mask` is unused on HIP.
  const float shifted_real = __shfl_down(val.real, delta, width);
  const float shifted_imag = __shfl_down(val.imag, delta, width);
  return paddle::platform::complex<float>(shifted_real, shifted_imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex<double>
CudaShuffleDownSync(unsigned mask, paddle::platform::complex<double> val,
                    int delta, int width) {
  // HIP path: shuffle the real part, then the imaginary part, as two
  // independent double shuffles. `mask` is unused on HIP.
  const double shifted_real = __shfl_down(val.real, delta, width);
  const double shifted_imag = __shfl_down(val.imag, delta, width);
  return paddle::platform::complex<double>(shifted_real, shifted_imag);
}

template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
                                                      float16 val,
                                                      int lane_mask) {
  // HIP path: widen to float, XOR-shuffle with `lane_mask` (the laneMask
  // operand of __shfl_xor), then narrow back to float16. `mask` is unused.
  const float widened = static_cast<float>(val);
  return float16(__shfl_xor(widened, lane_mask));
}

template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
    unsigned mask, paddle::platform::complex<float> val, int lane_mask) {
  // HIP path: XOR-shuffle the real and imaginary parts separately, using
  // `lane_mask` as __shfl_xor's laneMask. `mask` is unused on HIP.
  const float partner_real = __shfl_xor(val.real, lane_mask);
  const float partner_imag = __shfl_xor(val.imag, lane_mask);
  return paddle::platform::complex<float>(partner_real, partner_imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
    unsigned mask, paddle::platform::complex<double> val, int lane_mask) {
  // HIP path: XOR-shuffle the real and imaginary parts separately, using
  // `lane_mask` as __shfl_xor's laneMask. `mask` is unused on HIP.
  const double partner_real = __shfl_xor(val.real, lane_mask);
  const double partner_imag = __shfl_xor(val.imag, lane_mask);
  return paddle::platform::complex<double>(partner_real, partner_imag);
}
#else
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
  // CUDA path: shuffle the value as a native `half` among the lanes in `mask`,
  // then wrap the result back into float16.
  const half as_half = static_cast<half>(val);
  const unsigned lane_delta = static_cast<unsigned>(delta);
  return float16(__shfl_down_sync(mask, as_half, lane_delta, width));
}

template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
    unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
  // CUDA path: two independent float shuffles (real, then imaginary) over the
  // lanes in `mask`.
  const unsigned lane_delta = static_cast<unsigned>(delta);
  const float shifted_real =
      __shfl_down_sync(mask, static_cast<float>(val.real), lane_delta, width);
  const float shifted_imag =
      __shfl_down_sync(mask, static_cast<float>(val.imag), lane_delta, width);
  return paddle::platform::complex<float>(shifted_real, shifted_imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex<double>
CudaShuffleDownSync(unsigned mask, paddle::platform::complex<double> val,
                    int delta, int width) {
  // CUDA path: two independent double shuffles (real, then imaginary) over the
  // lanes in `mask`.
  const unsigned lane_delta = static_cast<unsigned>(delta);
  const double shifted_real =
      __shfl_down_sync(mask, static_cast<double>(val.real), lane_delta, width);
  const double shifted_imag =
      __shfl_down_sync(mask, static_cast<double>(val.imag), lane_delta, width);
  return paddle::platform::complex<double>(shifted_real, shifted_imag);
}

template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
                                                      float16 val,
                                                      int lane_mask) {
  // CUDA path: XOR-shuffle the value as a native `half` among the lanes in
  // `mask`, using `lane_mask` as __shfl_xor_sync's laneMask.
  const half as_half = static_cast<half>(val);
  return float16(__shfl_xor_sync(mask, as_half, lane_mask));
}

template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
    unsigned mask, paddle::platform::complex<float> val, int lane_mask) {
  // CUDA path: XOR-shuffle the real and imaginary parts separately among the
  // lanes in `mask`; `lane_mask` is __shfl_xor_sync's laneMask.
  const float partner_real =
      __shfl_xor_sync(mask, static_cast<float>(val.real), lane_mask);
  const float partner_imag =
      __shfl_xor_sync(mask, static_cast<float>(val.imag), lane_mask);
  return paddle::platform::complex<float>(partner_real, partner_imag);
}

template <>
__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
    unsigned mask, paddle::platform::complex<double> val, int lane_mask) {
  // CUDA path: XOR-shuffle the real and imaginary parts separately among the
  // lanes in `mask`; `lane_mask` is __shfl_xor_sync's laneMask.
  const double partner_real =
      __shfl_xor_sync(mask, static_cast<double>(val.real), lane_mask);
  const double partner_imag =
      __shfl_xor_sync(mask, static_cast<double>(val.imag), lane_mask);
  return paddle::platform::complex<double>(partner_real, partner_imag);
}
#endif

// Portable indexed warp shuffle: each lane receives the `val` held by lane
// `src_line` of its `width`-lane segment.
//   CUDA: forwards to __shfl_sync; only lanes set in `mask` participate.
//   HIP:  forwards to __shfl, which takes no member mask, so `mask` is ignored.
// NOTE(review): unlike CudaShuffleDownSync, this defaults `width` to 32 rather
// than `warpSize`. On CUDA (32-lane warps) the two coincide. On a HIP target
// with 64-wide wavefronts, the default restricts the shuffle to 32-lane
// segments — confirm that this is intended.
template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                             int width = 32) {
#if !defined(PADDLE_WITH_HIP)
  return __shfl_sync(mask, val, src_line, width);
#else
  return __shfl(val, src_line, width);
#endif
}

// Returns positive infinity (the float-valued INFINITY macro) converted to T.
// T must be copy-initializable from a float.
template <typename T>
HOSTDEVICE T Infinity() {
  T positive_infinity = INFINITY;
  return positive_infinity;
}

// Block-wide sum of `val` computed with two rounds of warp shuffle-down:
//   1. every warp folds its lanes' values into the warp's lane 0;
//   2. each lane 0 stores its warp's partial in shared memory, and the first
//      warp (tid < kWarpSize) folds those partials together.
// Only the thread with tid == 0 ends up holding the complete block sum. Other
// threads return intermediate partials and should ignore the result.
//
// Preconditions (assumed, not checked — TODO confirm against callers):
//   * `tid` is the thread's linear index in the block, and every thread of
//     the block calls this function (it executes __syncthreads()).
//   * The block has at most kWarpSize warps: the shared scratch array has one
//     slot per warp and kWarpSize slots in total.
//   * The `tid < len` mask does not zero out `val` for threads with
//     tid >= len, so such threads presumably must pass 0.
//
// NOTE(zcd): The warp size should ideally come from the GPU's parameters
// rather than being hard-coded. It is fixed here (32 on CUDA, 64 on HIP) so
// that the reduction can use warp-level shuffles and a compile-time-sized
// shared array.
// NOTE(review): on CUDA, threads with tid >= len still execute the first-round
// shuffles while their bit is clear in `mask`. The CUDA docs require every
// calling thread to be named in the mask, so len < (threads in block) looks
// unsafe. Confirm that callers always pass a `len` covering every thread.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
#ifdef PADDLE_WITH_HIP
  constexpr int kWarpSize = 64;
#else
  constexpr int kWarpSize = 32;
#endif
  // One slot per warp for the round-1 partial sums.
  __shared__ T warp_sums[kWarpSize];

  // Round 1: reduce inside each warp; lane 0 receives the warp's sum.
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);
  for (int stride = kWarpSize / 2; stride > 0; stride /= 2) {
    val += platform::CudaShuffleDownSync(mask, val, stride);
  }

  // Clear every slot so that absent warps contribute zero, then publish each
  // warp's partial. The barriers keep the clear, the publish and the round-2
  // reads from overlapping.
  if (tid < kWarpSize) {
    warp_sums[tid] = 0;
  }
  __syncthreads();
  const int lane = tid % kWarpSize;
  const int warp = tid / kWarpSize;
  if (lane == 0) {
    warp_sums[warp] = val;
  }
  __syncthreads();

  // Round 2: the first warp folds the per-warp partials; thread 0 receives
  // the block sum.
  CREATE_SHFL_MASK(mask, tid < kWarpSize);
  if (tid < kWarpSize) {
    val = warp_sums[tid];
    for (int stride = kWarpSize / 2; stride > 0; stride /= 2) {
      val += platform::CudaShuffleDownSync(mask, val, stride);
    }
  }
  return val;
}

}  // namespace platform
}  // namespace paddle