math_cuda_utils.h 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16 17

#ifdef PADDLE_WITH_CUDA
18
#include <cuda_fp16.h>
19 20 21 22 23
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

24 25
#include <algorithm>

26
namespace phi {
27
namespace funcs {
28 29 30 31 32 33 34

// Primary templates for the conversion / math helpers. This header provides
// no generic definitions; only the explicit specializations further down are
// defined, so using an unsupported T fails at link time.
template <typename T>
__device__ __forceinline__ T FromFloat(float x);

template <typename T>
__device__ __forceinline__ float ToFloat(T x);

template <typename T>
__device__ __forceinline__ float2 ToFloat2(T x);

template <typename T>
__device__ __forceinline__ T exp_func(T x);

template <typename T>
__device__ __forceinline__ T FloatsToPair(const float first,
                                          const float second);

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// Primary template; only the float and half specializations are defined.
template <class T>
struct KeyValuePair;

// Shorthand alias for KeyValuePair<T>.
template <class T>
using kvp = KeyValuePair<T>;

// FromFloat<T>: convert a float into T.
template <>
__device__ __forceinline__ float FromFloat<float>(float x) {
  return x;  // identity conversion
}

template <>
__device__ __forceinline__ half FromFloat<half>(float x) {
  const half narrowed = __float2half(x);
  return narrowed;
}

// ToFloat<T>: convert a T into float.
template <>
__device__ __forceinline__ float ToFloat<float>(float x) {
  return x;  // already a float
}

67 68 69 70 71 72
// ToFloat2<float2>: a float2 is already the target type.
template <>
__device__ __forceinline__ float2 ToFloat2<float2>(float2 x) {
  return x;
}

// FloatsToPair<float2>: pack two floats as {.x = a, .y = b}.
template <>
__device__ __forceinline__ float2 FloatsToPair<float2>(const float a,
                                                       const float b) {
  const float2 packed = make_float2(a, b);
  return packed;
}

// Component-wise addition of two float2 values.
__inline__ __device__ float2 operator+(const float2 &a, const float2 &b) {
  const float sum_x = a.x + b.x;
  const float sum_y = a.y + b.y;
  return make_float2(sum_x, sum_y);
}

82 83 84 85
// ToFloat<half>: widen a half to float.
template <>
__device__ __forceinline__ float ToFloat<half>(half x) {
  const float widened = __half2float(x);
  return widened;
}
86 87 88 89 90 91 92

// ToFloat2<__half2>: widen both halves of a __half2 into a float2.
template <>
__device__ __forceinline__ float2 ToFloat2<__half2>(__half2 x) {
  const float2 widened = __half22float2(x);
  return widened;
}

// FloatsToPair<__half2>: pack two floats into a __half2 using the
// round-to-nearest (_rn) conversion intrinsic.
template <>
__device__ __forceinline__ __half2 FloatsToPair<__half2>(const float a,
                                                         const float b) {
  const __half2 packed = __floats2half2_rn(a, b);
  return packed;
}
97 98 99 100 101 102 103 104

// exp_func<float>: single-precision exponential.
template <>
__device__ __forceinline__ float exp_func<float>(float x) {
  const float result = expf(x);
  return result;
}

// exp_func<half>: half-precision exponential.
// NOTE(review): CUDA_ARCH_FP16_SUPPORTED is not defined in this header;
// presumably it comes from an earlier include — confirm.
template <>
__device__ __forceinline__ half exp_func<half>(half x) {
#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  // Native half-precision exponential.
  return hexp(x);
#else
  // No FP16 math on this target: compute in float and narrow the result.
  const float widened = ToFloat<half>(x);
  return FromFloat<half>(expf(widened));
#endif
}

// Pair of floats that can be summed component-wise with operator+.
template <>
struct KeyValuePair<float> {
  float key;
  float value;

  // Leaves key/value uninitialized.
  __device__ __forceinline__ KeyValuePair() {}
  __device__ __forceinline__ KeyValuePair(float k, float v)
      : key(k), value(v) {}
  __device__ __forceinline__ KeyValuePair(const KeyValuePair &other)
      : key(other.key), value(other.value) {}

  // Component-wise sum: {key + rhs.key, value + rhs.value}.
  __device__ __forceinline__ KeyValuePair
  operator+(const KeyValuePair &rhs) const {
    return KeyValuePair(key + rhs.key, value + rhs.value);
  }
};

// Pair of halves that can be summed component-wise with operator+; the two
// fields are packed into a half2 so both lanes are added together.
template <>
struct KeyValuePair<half> {
  half key;
  half value;

  // Leaves key/value uninitialized.
  __device__ __forceinline__ KeyValuePair() {}
  __device__ __forceinline__ KeyValuePair(half k, half v) : key(k), value(v) {}
  __device__ __forceinline__ KeyValuePair(const KeyValuePair &other)
      : key(other.key), value(other.value) {}

  // Component-wise sum: {key + rhs.key, value + rhs.value}.
  __device__ __forceinline__ KeyValuePair
  operator+(const KeyValuePair &rhs) const {
    const half2 lhs2 = __halves2half2(key, value);
    const half2 rhs2 = __halves2half2(rhs.key, rhs.value);
#ifdef PADDLE_WITH_CUDA
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
    // Native packed half addition.
    const half2 sum2 = __hadd2(lhs2, rhs2);
#else
    // No native FP16 arithmetic: add each lane in float, then round back.
    const half2 sum2 =
        __floats2half2_rn(__low2float(lhs2) + __low2float(rhs2),
                          __high2float(lhs2) + __high2float(rhs2));
#endif
    return KeyValuePair(sum2.x, sum2.y);
#else  // PADDLE_WITH_HIP
    const half2 sum2 = __hadd2(lhs2, rhs2);
    return KeyValuePair(__low2half(sum2), __high2half(sum2));
#endif
  }
};

R
ronnywang 已提交
166 167 168 169 170 171 172 173 174 175 176 177
// NOTE(wangran16): The warpSize built-in is an int holding the number of
// threads per warp for the target device. NVIDIA devices use 32; this header
// assumes 64-wide wavefronts for every HIP build.
// NOTE(review): some AMD GPUs can run in wave32 mode, which these compile-time
// constants do not cover — confirm against the targeted ROCm architectures.
// Prefer the warpSize built-in in device code where a runtime value suffices.
#ifdef PADDLE_WITH_HIP
// 64-bit lane mask. Use the standard `unsigned long long` (the type HIP's
// cross-lane intrinsics such as __ballot use) instead of the non-standard
// BSD/glibc `u_int64_t`, and an ULL literal because `UL` is only 32 bits on
// LLP64 platforms.
#define FINAL_MASK 0xffffffffffffffffULL
#define HALF_WARP 32
#define WARP_SIZE 64
#define WARP_SIZE_WIDTH 6
#define WARP_SIZE_WIDTH_MASK 0x3f
typedef unsigned long long warp_mask_t;
#else
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
#define WARP_SIZE_WIDTH 5
#define WARP_SIZE_WIDTH_MASK 0x1f
typedef unsigned warp_mask_t;
#endif
185 186

/* Sum `val` over the warp with an XOR butterfly (offsets HALF_WARP .. 1);
 * every participating lane ends up holding the warp total.
 * lane_mask is only consumed by the CUDA *_sync shuffle path. */
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, warp_mask_t lane_mask) {
  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    const T partner = __shfl_xor_sync(lane_mask, val, offset, warpSize);
#else
    const T partner = __shfl_xor(val, offset, warpSize);
#endif
    val += partner;
  }
  return val;
}

/* Sum of all elements in a block; the total is returned in every thread.
 * Stage 1: each warp reduces its lanes and lane 0 of warp `wid` stores the
 * partial in shared[wid]. Stage 2: every warp reloads the partials (lanes past
 * the warp count read 0) and reduces them again.
 * Must be reached by all threads of the block (contains __syncthreads).
 * NOTE(review): indexes with threadIdx.x / blockDim.x only, so presumably
 * expects a 1-D block of at most WARP_SIZE warps — confirm at call sites. */
template <typename T>
__inline__ __device__ T BlockReduceSum(T val, warp_mask_t mask) {
  static __shared__ T shared[WARP_SIZE];
  const int wid = threadIdx.x >> WARP_SIZE_WIDTH;
  const int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;

  val = WarpReduceSum<T>(val, mask);

  // Make sure no thread is still reading `shared` from a previous call
  // before lane 0 overwrites it.
  __syncthreads();
  if (lane == 0) {
    shared[wid] = val;
  }
  __syncthreads();

  // Number of warps in the block, rounded up.
  const int num_warps = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
  if (lane < num_warps) {
    val = shared[lane];
  } else {
    val = static_cast<T>(0.0f);
  }
  return WarpReduceSum<T>(val, mask);
}

219 220 221 222 223 224 225 226
/*
Reduce NUM independent values across the warp in place: afterwards val[k] in
every lane holds the warp-wide sum of val[k]. Uses the *_sync shuffle with
FINAL_MASK, so the whole warp must participate. The return value is a dummy 0.
*/
template <typename T, int NUM>
__inline__ __device__ T WarpReduceSumV2(T *val) {
#pragma unroll
  for (int k = 0; k < NUM; ++k) {
    T &acc = val[k];
#pragma unroll
    for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
      acc += __shfl_xor_sync(FINAL_MASK, acc, offset, WARP_SIZE);
    }
  }
  return static_cast<T>(0.0f);
}

/* Block-wide sum of NUM values at once, in place on val[0..NUM).
 * After the call, threads of warp 0 hold the block totals in val[]; threads of
 * other warps hold 0 (all their lanes fail the is_mask test). The return value
 * is a dummy 0. Must be reached by all threads of the block.
 * NOTE(review): shared has 33 columns per row, so at most 33 warps are
 * supported — confirm block sizes at call sites. */
template <typename T, int NUM>
__inline__ __device__ T BlockReduceSumV2(T *val) {
  static __shared__ T shared[NUM][33];
  const int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
  const int wid = threadIdx.x >> WARP_SIZE_WIDTH;

  WarpReduceSumV2<T, NUM>(val);

  // Barrier before overwriting `shared`: without it, a warp entering a second
  // back-to-back call could overwrite partials that warp 0 is still reading
  // from the previous call (the same guard BlockReduceSum uses).
  __syncthreads();
  if (lane == 0) {
#pragma unroll
    for (int i = 0; i < NUM; i++) {
      shared[i][wid] = val[i];
    }
  }
  __syncthreads();

  // Only the first ceil(blockDim.x / WARP_SIZE) threads load a real partial.
  // Integer ceil-division; equivalent to the former float comparison.
  const bool is_mask =
      threadIdx.x < ((blockDim.x + WARP_SIZE - 1) >> WARP_SIZE_WIDTH);
#pragma unroll
  for (int i = 0; i < NUM; i++) {
    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
  }
  WarpReduceSumV2<T, NUM>(val);
  return (T)0.0f;
}

259
/* Maximum of `val` over the warp via an XOR butterfly; every participating
 * lane receives the result. lane_mask is only used on the CUDA *_sync path. */
template <typename T>
__inline__ __device__ T WarpReduceMax(T val, warp_mask_t lane_mask) {
  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    const T other = __shfl_xor_sync(lane_mask, val, offset, warpSize);
#else
    const T other = __shfl_xor(val, offset, warpSize);
#endif
    val = max(val, other);
  }
  return val;
}

270 271 272 273 274
/* In-place warp-wide maximum of NUM independent values: afterwards val[k] in
 * every lane holds the maximum of val[k] across the warp. Uses FINAL_MASK, so
 * the whole warp must participate. The return value is a dummy 0. */
template <typename T, int NUM>
__inline__ __device__ T WarpReduceMaxV2(T *val) {
#pragma unroll
  for (int k = 0; k < NUM; ++k) {
    T &best = val[k];
#pragma unroll
    for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
      best = max(best, __shfl_xor_sync(FINAL_MASK, best, offset, WARP_SIZE));
    }
  }
  return static_cast<T>(0.0f);
}

282
/* Minimum of `val` over the warp via an XOR butterfly; every participating
 * lane receives the result. lane_mask is only used on the CUDA *_sync path. */
template <typename T>
__inline__ __device__ T WarpReduceMin(T val, warp_mask_t lane_mask) {
  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    const T other = __shfl_xor_sync(lane_mask, val, offset, warpSize);
#else
    const T other = __shfl_xor(val, offset, warpSize);
#endif
    val = min(val, other);
  }
  return val;
}

/* Calculate the minimum of the elements in a warp using a shfl_down tree.
 * Only lane 0 is guaranteed to hold the minimum over the whole warp; other
 * lanes hold partial results.
 * NOTE(review): __shfl_down_sync still reads from source lanes outside
 * lane_mask (e.g. missing threads of a partial warp), which CUDA leaves
 * undefined — confirm callers keep the warp fully populated. */
template <typename T>
__inline__ __device__ T PartialWarpReduceMin(T val, warp_mask_t lane_mask) {
  // Each lane starts from its own value. (An initial broadcast of lane 0 was
  // previously computed here and then immediately overwritten — a dead store.)
  T warp_val = val;

  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    warp_val =
        min(warp_val, __shfl_down_sync(lane_mask, warp_val, offset, warpSize));
#else
    warp_val = min(warp_val, __shfl_down(warp_val, offset, warpSize));
#endif
  }
  return warp_val;
}

315 316
/* Calculate the maximum of all elements in a block; the result is returned in
 * every thread. Stage 1: each warp reduces its lanes and lane 0 publishes the
 * warp maximum to shared[wid]. Stage 2: every warp reloads the partials and
 * reduces them again. Must be reached by all threads of the block. */
template <typename T>
__inline__ __device__ T BlockReduceMax(T val, warp_mask_t mask) {
  static __shared__ T shared[WARP_SIZE];
  const int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
  const int wid = threadIdx.x >> WARP_SIZE_WIDTH;

  val = WarpReduceMax(val, mask);

  // Barrier before overwriting `shared` so a back-to-back call cannot clobber
  // partials another warp is still reading (same guard as BlockReduceSum).
  __syncthreads();
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // align block_span to warpSize
  const int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
  // Pad unused lanes with a real partial instead of a magic sentinel:
  // max(x, x) == x, so shared[0] is neutral for any T and any value range.
  // The former -1e10f sentinel won whenever every input was below -1e10.
  val = (lane < block_span) ? shared[lane] : shared[0];
  val = WarpReduceMax(val, mask);

  return val;
}

336 337
/* Block-wide maximum of NUM values at once, in place on val[0..NUM).
 * After the call, threads of warp 0 hold the block maxima in val[]; threads of
 * other warps do NOT receive the block result (they see only warp 0's
 * partials). The return value is a dummy 0. Must be reached by all threads of
 * the block. */
template <typename T, int NUM>
__inline__ __device__ T BlockReduceMaxV2(T *val) {
  static __shared__ T shared[WARP_SIZE][NUM];
  const int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;  // in-warp idx
  const int wid = threadIdx.x >> WARP_SIZE_WIDTH;       // warp idx

  WarpReduceMaxV2<T, NUM>(val);  // per-warp maxima

  // Barrier before overwriting `shared` so a back-to-back call cannot clobber
  // partials warp 0 is still reading from the previous call.
  __syncthreads();
  if (lane == 0) {  // record in-warp max by warp idx
#pragma unroll
    for (int i = 0; i < NUM; i++) {
      shared[wid][i] = val[i];
    }
  }
  __syncthreads();

  // Only the first ceil(blockDim.x / WARP_SIZE) threads load a real partial
  // (integer ceil-division; handles blockDim.x not divisible by WARP_SIZE).
  const bool is_mask =
      threadIdx.x < ((blockDim.x + WARP_SIZE - 1) >> WARP_SIZE_WIDTH);
#pragma unroll
  for (int i = 0; i < NUM; i++) {
    // Pad with a real partial (max is idempotent) rather than the old -1e20f
    // sentinel, which won when every input was below -1e20.
    val[i] = is_mask ? shared[lane][i] : shared[0][i];
  }
  WarpReduceMaxV2<T, NUM>(val);

  return (T)0.0f;
}

365 366
/* Calculate the minimum of all elements in a block; the result is returned in
 * every thread. Stage 1: each warp reduces its lanes and lane 0 publishes the
 * warp minimum to shared[wid]. Stage 2: every warp reloads the partials and
 * reduces them again. Must be reached by all threads of the block. */
template <typename T>
__inline__ __device__ T BlockReduceMin(T val, warp_mask_t mask) {
  static __shared__ T shared[WARP_SIZE];
  const int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
  const int wid = threadIdx.x >> WARP_SIZE_WIDTH;

  val = WarpReduceMin(val, mask);

  // Barrier before overwriting `shared` so a back-to-back call cannot clobber
  // partials another warp is still reading (same guard as BlockReduceSum).
  __syncthreads();
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // align block_span to warpSize
  const int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
  // Pad unused lanes with a real partial (min(x, x) == x) instead of the old
  // 1e10f sentinel, which won whenever every input exceeded 1e10.
  val = (lane < block_span) ? shared[lane] : shared[0];
  val = WarpReduceMin(val, mask);

  return val;
}

/* Calculate the minimum of all elements in a block whose thread count may be
 * smaller than warpSize; the result is broadcast to every lane of each warp.
 *
 * Stage 1: each warp reduces its values with PartialWarpReduceMin (result in
 * lane 0) and lane 0 publishes it to shared[wid]. Stage 2: every warp reloads
 * only the published partials; lanes past the warp count are padded with
 * shared[0] (min(x, x) == x), so slots never written in this call are never
 * read. Must be reached by all threads of the block (contains __syncthreads).
 */
template <typename T>
__inline__ __device__ T PartialBlockReduceMin(T val, warp_mask_t mask) {
  static __shared__ T shared[WARP_SIZE];
  const int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
  const int wid = threadIdx.x >> WARP_SIZE_WIDTH;

  val = PartialWarpReduceMin(val, mask);  // valid in lane 0

  // Barrier before overwriting `shared` so a back-to-back call cannot clobber
  // partials another warp is still reading.
  __syncthreads();
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // Read only the partials written above; shared is never written in stage 2,
  // so warps no longer race on it.
  const int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
  val = (lane < block_span) ? shared[lane] : shared[0];
  val = PartialWarpReduceMin(val, mask);  // valid in lane 0

  // Broadcast lane 0's block minimum to the rest of the warp.
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
  val = __shfl_sync(mask, val, 0, warpSize);
#else
  val = __shfl(val, 0, warpSize);
#endif
  return val;
}

414
}  // namespace funcs
415
}  // namespace phi