cuda_device_function.h 4.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16

17
#include <cuda.h>
18 19 20 21
// NOTE: PADDLE_CUDA_FP16 is defined before including float16.h so that the header provides its float16-to-half conversion support.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/float16.h"
22 23 24 25 26 27 28 29 30 31

namespace paddle {
namespace platform {

// CREATE_SHFL_MASK(mask, predicate) assigns to `mask` the set of warp lanes
// that take part in a following *_sync shuffle:
//  - CUDA >= 9.0: the lanes of the calling warp for which `predicate` is true,
//    computed with __ballot_sync over FULL_WARP_MASK (all 32 lanes), so every
//    lane of the warp must execute the macro.
//  - CUDA < 9.0: the legacy shuffles take no mask, so `mask` is set to 0 and
//    `predicate` is not evaluated.
// Neither expansion ends with a semicolon. The call site supplies it, so the
// macro behaves like one statement (e.g. inside an unbraced if/else).
// FULL_WARP_MASK is only defined for CUDA >= 9.0.
#if CUDA_VERSION < 9000
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Maps a requested thread count onto one of the case values used by
// CUDA_LAUNCH_KERNEL_HELPER (32, 64, 128, 256, 512, 1024). The result is the
// smallest power of two that is >= dim, clamped to [32, 1024]:
//  - any dim <= 32, including zero and negative values, yields 32;
//  - any dim > 512 yields 1024.
inline static int RoundToPowerOfTwo(int dim) {
  constexpr int kSmallest = 32;
  constexpr int kLargest = 1024;
  int rounded = kSmallest;
  // Double until we cover `dim` or hit the cap; `rounded` never exceeds 1024,
  // so the doubling cannot overflow.
  while (rounded < kLargest && rounded < dim) {
    rounded *= 2;
  }
  return rounded;
}

// CUDA_LAUNCH_KERNEL_BASE(dim, ...) emits one `case` of a switch statement.
// For the case value `dim` it defines the compile-time constant
// kPowerOfTwoDim == dim in its own scope. It then runs the statement passed
// in __VA_ARGS__, which may use kPowerOfTwoDim (e.g. as a template argument),
// and breaks out of the switch.
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...)  \
  case (dim): {                            \
    constexpr auto kPowerOfTwoDim = (dim); \
    __VA_ARGS__;                           \
  } break

// CUDA_LAUNCH_KERNEL_HELPER(...) expands to the cases 1024, 512, 256, 128, 64
// and 32 of a switch, each built with CUDA_LAUNCH_KERNEL_BASE. These are
// exactly the values RoundToPowerOfTwo can return. Presumably it is used as
// `switch (RoundToPowerOfTwo(n)) { CUDA_LAUNCH_KERNEL_HELPER(...); }`; confirm
// against callers.
// `, ##__VA_ARGS__` relies on the GNU comma-elision extension.
#define CUDA_LAUNCH_KERNEL_HELPER(...)          \
  CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__); \
  CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__);  \
  CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__);   \
  CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);

// Warp shuffle-down usable with both pre- and post-CUDA-9.0 toolkits. It
// returns the `val` held by the lane `delta` positions higher within the
// caller's `width`-lane segment.
//  - CUDA >= 9.0: forwards to __shfl_down_sync. `mask` must name every lane
//    that executes the call (see CREATE_SHFL_MASK).
//  - CUDA < 9.0: forwards to the legacy __shfl_down, and `mask` is ignored.
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                 int delta, int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl_down(val, delta, width);
#else
  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
#endif
}

// float16 specialization of CudaShuffleDownSync. The value is shuffled as the
// CUDA `half` type and converted back to float16. CUDA 9.0 and later use the
// __shfl_down_sync overload for half; older toolkits use __shfl_down.
// An explicit specialization cannot redeclare default arguments, so `width`
// uses the primary template's default of 32.
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
#if CUDA_VERSION < 9000
  return float16(
      __shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
#else
  return float16(__shfl_down_sync(mask, static_cast<half>(val),
                                  static_cast<unsigned>(delta), width));
#endif
}

// Warp shuffle (indexed read) usable with both pre- and post-CUDA-9.0
// toolkits. It returns the `val` held by lane `src_line` of the caller's
// `width`-lane segment.
//  - CUDA >= 9.0: forwards to __shfl_sync. `mask` must name every lane that
//    executes the call (see CREATE_SHFL_MASK).
//  - CUDA < 9.0: forwards to the legacy __shfl, and `mask` is ignored.
template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                             int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl(val, src_line, width);
#else
  return __shfl_sync(mask, val, src_line, width);
#endif
}

// Returns positive infinity converted to T. The value comes from the
// INFINITY macro, which is a float constant.
// HOSTDEVICE is a project macro; it presumably expands to
// `__host__ __device__`, making this callable from both sides. Confirm in
// its definition.
template <typename T>
HOSTDEVICE T Infinity() {
  return INFINITY;
}

// Sums `val` across the threads of a block with two rounds of warp shuffles:
//  1. Each warp reduces its own 32 values.
//  2. Lane 0 of every warp stores its partial sum in shared memory.
//  3. The first warp reduces those partial sums.
//
// Contract (not checked at runtime):
//  - Every thread of the block must call this function. It contains
//    __syncthreads() and full-warp __ballot_sync / __shfl_down_sync calls.
//  - `tid` is the caller's index within the block (presumably threadIdx.x;
//    confirm with callers).
//  - The block is assumed to have a multiple of 32 threads and at most 1024
//    threads, since `shm` holds one slot per warp.
//  - Threads with tid >= len contribute nothing to the sum.
// Only the thread with tid == 0 receives the complete sum. Every other
// thread's return value is a partial result.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
  // NOTE(zcd): The warp size should be taken from the
  // parameters of the GPU but not specified as 32 simply.
  // To make the reduceSum more efficiently,
  // I use Warp-Level Parallelism and assume the Warp size
  // is 32 which may be different for different GPU,
  // but most card's warp size is 32.
  // The constant is named kWarpSize so it does not shadow CUDA's built-in
  // `warpSize`.
  constexpr int kWarpSize = 32;
  __shared__ T shm[kWarpSize];

  // Threads past `len` still execute the shuffles below, for two reasons:
  // every lane that executes a *_sync shuffle must be named in its mask, and
  // lanes may only read values from lanes in that mask. These threads take
  // part with a zero value, so they add nothing to the sum.
  if (tid >= len) val = 0;
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);

  // Round 1: reduce within each warp; lane 0 ends up with the warp's sum.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    val += platform::CudaShuffleDownSync(mask, val, offset);

  // Clear all per-warp slots, so slots of warps the block does not have
  // read as zero in round 2.
  if (tid < kWarpSize) shm[tid] = 0;
  __syncthreads();

  if (tid % kWarpSize == 0) {
    shm[tid / kWarpSize] = val;
  }
  __syncthreads();

  // Round 2: the first warp reduces the per-warp partial sums.
  CREATE_SHFL_MASK(mask, tid < kWarpSize);

  if (tid < kWarpSize) {
    val = shm[tid];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      val += platform::CudaShuffleDownSync(mask, val, offset);
  }
  return val;
}

}  // namespace platform
}  // namespace paddle