// math_cuda_utils.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include <algorithm>

namespace paddle {
namespace operators {
namespace math {

// Primary templates for the element-type helpers in this header. Only
// declarations are given here; this header defines explicit specializations
// for the types it supports. Instantiating any other T has no definition
// and therefore fails to build.

// Converts a float to T.
template <typename T>
__device__ __forceinline__ T FromFloat(float a);

// Converts a T to float.
template <typename T>
__device__ __forceinline__ float ToFloat(T a);

// Converts a two-element vector type T (float2 / __half2 here) to float2.
template <typename T>
__device__ __forceinline__ float2 ToFloat2(T a);

// Computes e^a in type T.
template <typename T>
__device__ __forceinline__ T exp_func(T a);

// Packs the two floats (a, b), in that order, into a two-element vector
// type T.
template <typename T>
__device__ __forceinline__ T FloatsToPair(const float a, const float b);

// A (key, value) pair with element-wise operator+; specialized for float
// and half in this header.
template <typename T>
struct KeyValuePair;

// Short alias for KeyValuePair<T>.
template <typename T>
using kvp = KeyValuePair<T>;

// FromFloat<T> specializations: float -> T.

// float -> float is the identity.
template <>
__device__ __forceinline__ float FromFloat<float>(float a) {
  const float out = a;
  return out;
}

// float -> half, narrowed with the __float2half intrinsic.
template <>
__device__ __forceinline__ half FromFloat<half>(float a) {
  const half out = __float2half(a);
  return out;
}

// ToFloat<T> specialization for float: the identity.
template <>
__device__ __forceinline__ float ToFloat<float>(float a) {
  return a;
}

// ToFloat2<T> specialization for float2: the identity.
template <>
__device__ __forceinline__ float2 ToFloat2<float2>(float2 a) {
  return a;
}

// FloatsToPair<float2>: builds a float2 from (a, b) via make_float2.
template <>
__device__ __forceinline__ float2 FloatsToPair<float2>(const float a,
                                                       const float b) {
  return make_float2(a, b);
}

// Component-wise addition of two float2 values: (a.x + b.x, a.y + b.y).
__inline__ __device__ float2 operator+(const float2 &a, const float2 &b) {
  float2 sum;
  sum.x = a.x + b.x;
  sum.y = a.y + b.y;
  return sum;
}

// ToFloat<half>: widens a half to float with __half2float.
template <>
__device__ __forceinline__ float ToFloat<half>(half a) {
  return __half2float(a);
}

// ToFloat2<__half2>: widens both halves of a __half2 into a float2.
template <>
__device__ __forceinline__ float2 ToFloat2<__half2>(__half2 a) {
  const float2 widened = __half22float2(a);
  return widened;
}

// FloatsToPair<__half2>: packs (a, b) into a __half2 using the
// __floats2half2_rn conversion intrinsic.
template <>
__device__ __forceinline__ __half2 FloatsToPair<__half2>(const float a,
                                                         const float b) {
  return __floats2half2_rn(a, b);
}

// exp_func<float>: single-precision e^a via expf.
template <>
__device__ __forceinline__ float exp_func<float>(float a) {
  const float result = expf(a);
  return result;
}

// exp_func<half>: e^a for half.
// Uses the native hexp when compiling with HIP or when
// CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) holds. Otherwise it widens to
// float, calls expf, and narrows the result back to half.
// NOTE(review): CUDA_ARCH_FP16_SUPPORTED is neither defined nor included by
// this header; it must come from an earlier include — confirm.
template <>
__device__ __forceinline__ half exp_func<half>(half a) {
#if defined(__HIPCC__) || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
  return hexp(a);
#else
  return FromFloat<half>(expf(ToFloat<half>(a)));
#endif
}

// Single-precision (key, value) pair. The default constructor leaves both
// fields uninitialized; operator+ adds the fields element-wise.
template <>
struct KeyValuePair<float> {
  float key;
  float value;

  __device__ __forceinline__ KeyValuePair() {}
  __device__ __forceinline__ KeyValuePair(float k, float v)
      : key(k), value(v) {}
  __device__ __forceinline__ KeyValuePair(const KeyValuePair &a)
      : key(a.key), value(a.value) {}

  // Returns (key + a.key, value + a.value).
  __device__ __forceinline__ KeyValuePair
  operator+(const KeyValuePair &a) const {
    return KeyValuePair(key + a.key, value + a.value);
  }
};

// Half-precision (key, value) pair. The default constructor leaves both
// fields uninitialized; operator+ adds the fields element-wise.
template <>
struct KeyValuePair<half> {
  __device__ __forceinline__ KeyValuePair() {}
  __device__ __forceinline__ KeyValuePair(half k, half v) : key(k), value(v) {}
  __device__ __forceinline__ KeyValuePair(const KeyValuePair &a) {
    key = a.key;
    value = a.value;
  }
  half key;
  half value;
  // Returns (key + a.key, value + a.value).
  __device__ __forceinline__ KeyValuePair
  operator+(const KeyValuePair &a) const {
    // Pack each (key, value) pair into one half2 so both sums can be done
    // with a single paired add where native fp16 math is available.
    const half2 a2 = __halves2half2(key, value);
    const half2 b2 = __halves2half2(a.key, a.value);
#ifdef PADDLE_WITH_CUDA
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
    const half2 res = __hadd2(a2, b2);
#else
    // CUDA_ARCH_FP16_SUPPORTED is false for this arch: add in float and
    // round the two sums back into a half2.
    float a2_1 = __low2float(a2);
    float a2_2 = __high2float(a2);
    float b2_1 = __low2float(b2);
    float b2_2 = __high2float(b2);
    float r1 = a2_1 + b2_1;
    float r2 = a2_2 + b2_2;
    const half2 res = __floats2half2_rn(r1, r2);
#endif
    return KeyValuePair(res.x, res.y);
#else  // PADDLE_WITH_HIP
    const half2 res = __hadd2(a2, b2);
    return KeyValuePair(__low2half(res), __high2half(res));
#endif
  }
};

// All 32 lane bits set. Not referenced in this header; presumably what
// callers pass as the `mask` / `lane_mask` argument of the reductions below
// — confirm.
#define FINAL_MASK 0xffffffff
// Starting shuffle offset (half of a 32-lane group) for the warp reduction
// loops below; offsets then halve down to 1.
#define HALF_WARP 16
// Size of the lane groups the reductions below work with; also the number
// of shared-memory slots in the block reductions.
#define WARP_SIZE 32

/* Sum `val` across each 32-lane group with an XOR-butterfly shuffle
 * (offsets 16, 8, 4, 2, 1). Assuming all 32 lanes participate, every lane
 * of a group ends with that group's total.
 * `lane_mask` is only passed on the CUDA path with sm_35+ and CUDA 9+, where
 * __shfl_xor_sync is used. The fallback path calls the mask-less __shfl_xor. */
template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
  for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
#else
    val += __shfl_xor(val, mask, warpSize);
#endif
  }
  return val;
}

/* Calculate the sum of all elements in a block.
 *
 * Two-stage reduction:
 *  1. Each 32-lane group sums its values with warpReduceSum.
 *  2. Lane 0 of every group stores that partial sum in a shared slot.
 *  3. Every group re-reduces the stored partial sums.
 * The block total is returned in every calling thread.
 *
 * Requirements (not checked):
 *  - Every thread of the block must call this (it contains __syncthreads()).
 *  - Only threadIdx.x / blockDim.x are consulted, so presumably a 1-D block
 *    of at most 32 * WARP_SIZE = 1024 threads (one slot per group) — confirm.
 *  - `mask` is forwarded unchanged to warpReduceSum.
 */
template <typename T>
__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val, mask);

  // A previous call of this function in the same kernel may still be reading
  // `shared` in another warp; wait before overwriting it.
  __syncthreads();
  if (lane == 0) shared[wid] = val;

  __syncthreads();

  // Number of 32-lane groups, i.e. how many slots of `shared` were written.
  // Round with WARP_SIZE rather than the runtime warpSize, so the count
  // matches the 32-lane grouping used for `lane`/`wid`. warpSize is not 32 on
  // every target (e.g. 64-wide AMD wavefronts under HIP).
  int block_span = (blockDim.x + WARP_SIZE - 1) >> 5;
  val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
  val = warpReduceSum<T>(val, mask);

  return val;
}

/* Maximum of `val` across each 32-lane group via an XOR-butterfly shuffle
 * (offsets 16, 8, 4, 2, 1). Assuming all 32 lanes participate, every lane
 * of a group ends with that group's maximum.
 * `lane_mask` is only used on the CUDA sm_35+ / CUDA 9+ path
 * (__shfl_xor_sync). */
template <typename T>
__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
  for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
    val = max(val, __shfl_xor(val, mask, warpSize));
#endif
  }
  return val;
}

/* Minimum of `val` across each 32-lane group via an XOR-butterfly shuffle
 * (offsets 16, 8, 4, 2, 1). Assuming all 32 lanes participate, every lane
 * of a group ends with that group's minimum.
 * `lane_mask` is only used on the CUDA sm_35+ / CUDA 9+ path
 * (__shfl_xor_sync). */
template <typename T>
__inline__ __device__ T warpReduceMin(T val, unsigned lane_mask) {
  for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
    val = min(val, __shfl_xor(val, mask, warpSize));
#endif
  }
  return val;
}

/* Calculate the minimum of `val` over each 32-lane group with a shuffle-down
 * tree (offsets 16, 8, 4, 2, 1).
 *
 * Unlike warpReduceMin this is not a butterfly: only the first lane of each
 * 32-lane group is guaranteed to end with the group minimum. Shuffle-down
 * does not wrap around, so higher lanes finish with minima over partial lane
 * ranges.
 *
 * NOTE(review): this is meant for warps with fewer than warpSize active
 * threads, but the shuffles still read up to lane + 16. Values read from
 * lanes that are inactive or absent from `lane_mask` are undefined — confirm
 * the masks callers pass. */
template <typename T>
__inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
  T warp_val = val;

  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
    warp_val =
        min(warp_val, __shfl_down_sync(lane_mask, warp_val, offset, warpSize));
#else
    warp_val = min(warp_val, __shfl_down(warp_val, offset, warpSize));
#endif
  }
  return warp_val;
}

/* Calculate the maximum of all elements in a block.
 *
 * Two-stage reduction:
 *  1. Each 32-lane group reduces with warpReduceMax.
 *  2. Lane 0 of every group stores its partial maximum in a shared slot.
 *  3. Every group re-reduces the stored values.
 * The block maximum is returned in every calling thread.
 *
 * Requirements (not checked):
 *  - Every thread of the block must call this (it contains __syncthreads()).
 *  - Only threadIdx.x / blockDim.x are consulted, so presumably a 1-D block
 *    of at most 1024 threads — confirm.
 *
 * NOTE(review): when the block has fewer than 32 groups, the unused lanes
 * contribute -1e10f. The result is therefore max(true max, -1e10f). This
 * floor is kept as-is because callers may rely on it.
 */
template <typename T>
__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceMax(val, mask);

  // A previous call of this function in the same kernel may still be reading
  // `shared` in another warp; wait before overwriting it.
  __syncthreads();
  if (lane == 0) shared[wid] = val;

  __syncthreads();

  // Number of 32-lane groups (slots written above). Rounded with WARP_SIZE,
  // not the runtime warpSize, so it matches the `lane`/`wid` grouping even
  // where warpSize != 32.
  int block_span = (blockDim.x + WARP_SIZE - 1) >> 5;
  val = (lane < block_span) ? shared[lane] : -1e10f;
  val = warpReduceMax(val, mask);

  return val;
}

/* Calculate the minimum of all elements in a block.
 *
 * Two-stage reduction:
 *  1. Each 32-lane group reduces with warpReduceMin.
 *  2. Lane 0 of every group stores its partial minimum in a shared slot.
 *  3. Every group re-reduces the stored values.
 * The block minimum is returned in every calling thread.
 *
 * Requirements (not checked):
 *  - Every thread of the block must call this (it contains __syncthreads()).
 *  - Only threadIdx.x / blockDim.x are consulted, so presumably a 1-D block
 *    of at most 1024 threads — confirm.
 *
 * NOTE(review): when the block has fewer than 32 groups, the unused lanes
 * contribute 1e10f. The result is therefore min(true min, 1e10f). This
 * ceiling is kept as-is because callers may rely on it.
 */
template <typename T>
__inline__ __device__ T blockReduceMin(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceMin(val, mask);

  // A previous call of this function in the same kernel may still be reading
  // `shared` in another warp; wait before overwriting it.
  __syncthreads();
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // Number of 32-lane groups (slots written above). Rounded with WARP_SIZE,
  // not the runtime warpSize, so it matches the `lane`/`wid` grouping even
  // where warpSize != 32.
  int block_span = (blockDim.x + WARP_SIZE - 1) >> 5;
  val = (lane < block_span) ? shared[lane] : 1e10f;
  val = warpReduceMin(val, mask);

  return val;
}

/* Calculate the minimum of all elements in a block, built on
 * PartialWarpReduceMin.
 *
 * Steps:
 *  1. Each 32-lane group reduces its values; only the group's first lane
 *     holds the result.
 *  2. That lane stores the result in a shared slot.
 *  3. Every group reduces the stored values in registers.
 *  4. The value held by lane 0 is broadcast to all lanes.
 * Every calling thread therefore receives the block minimum.
 *
 * Requirements (not checked):
 *  - Every thread of the block must call this (it contains __syncthreads()).
 *  - Only threadIdx.x / blockDim.x are consulted, so presumably a 1-D block
 *    of at most 1024 threads — confirm.
 *  - See PartialWarpReduceMin: lanes read by the shuffles must be active and
 *    included in `mask`.
 */
template <typename T>
__inline__ __device__ T PartialBlockReduceMin(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = PartialWarpReduceMin(val, mask);

  // A previous call of this function in the same kernel may still be reading
  // `shared` in another warp; wait before overwriting it.
  __syncthreads();
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // Only slots [0, block_span) were written by this call. The remaining
  // lanes reuse shared[0], which thread 0 always writes. Repeating a real
  // element never changes a minimum, and it avoids reading uninitialized
  // shared memory. Stage two only reads `shared`, so warps cannot race on it.
  int block_span = (blockDim.x + WARP_SIZE - 1) >> 5;
  T partial = (lane < block_span) ? shared[lane] : shared[0];
  partial = PartialWarpReduceMin(partial, mask);

  // Lane 0 now holds the block minimum; broadcast it to every lane.
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
  val = __shfl_sync(mask, partial, 0, warpSize);
#else
  val = __shfl(partial, 0, warpSize);
#endif
  return val;
}

}  // namespace math
}  // namespace operators
}  // namespace paddle