// cuda_device_function.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cuda.h>

// NOTE: PADDLE_CUDA_FP16 is defined before including float16.h so that the
// header's float16 <-> CUDA half support is compiled in.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace platform {

// CREATE_SHFL_MASK(mask, predicate) assigns to `mask` the set of lanes that
// takes part in a following CudaShuffle*Sync call.
//  * CUDA < 9.0: the legacy mask-less __shfl* intrinsics are used, so the mask
//    is unused and simply set to 0.
//  * CUDA >= 9.0: the mask is the ballot of `predicate` over the full warp;
//    __ballot_sync(FULL_WARP_MASK, ...) must be executed by all non-exited
//    lanes of the warp.
// Neither expansion carries a trailing semicolon, so the macro behaves like an
// ordinary expression statement in both branches (e.g. in an unbraced
// if/else).
#if CUDA_VERSION < 9000
#define CREATE_SHFL_MASK(mask, predicate) (mask) = 0u
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  (mask) = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Returns the value of `val` held by the lane `delta` positions above the
// calling lane, within segments of `width` lanes (wrapper over
// __shfl_down / __shfl_down_sync).
// `mask` is forwarded only on CUDA >= 9.0; the legacy intrinsic ignores it.
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                 int delta, int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl_down(val, static_cast<unsigned>(delta), width);
#else
  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
#endif
}

// Specialization for paddle's float16: the value is shuffled as CUDA `half`
// (via float16's conversion to half) and the result is converted back.
// The default `width` of 32 comes from the primary template.
// NOTE(review): the original comment stated that CUDA 9.0 added native half
// support for shfl_down, yet the CUDA < 9.0 branch also calls __shfl_down on a
// half value — confirm that overload exists before building with CUDA < 9.0.
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
                                                       float16 val, int delta,
                                                       int width) {
#if CUDA_VERSION < 9000
  return float16(
      __shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
#else
  return float16(__shfl_down_sync(mask, static_cast<half>(val),
                                  static_cast<unsigned>(delta), width));
#endif
}

// Returns the value of `val` held by lane `src_line` of the calling lane's
// segment of `width` lanes (wrapper over __shfl / __shfl_sync).
// `mask` is forwarded only on CUDA >= 9.0; the legacy intrinsic ignores it.
template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                             int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl(val, src_line, width);
#else
  return __shfl_sync(mask, val, src_line, width);
#endif
}

// Returns positive infinity (the INFINITY macro) converted to T.
// The explicit static_cast also accepts types that can only be built from a
// float through an explicit constructor (presumably float16 — verify against
// float16.h); for float/double it is identical to an implicit conversion.
// T is expected to be floating-point-like: converting INFINITY to an integral
// type is undefined behavior.
template <typename T>
HOSTDEVICE T Infinity() {
  return static_cast<T>(INFINITY);
}

// Block-wide sum of `val` over the threads whose tid < len.
//
// Assumptions (not checked here; confirm against callers):
//  * every thread of the block calls reduceSum — it contains __syncthreads()
//    and full-warp shuffles;
//  * `tid` is the thread's index in a 1-D block (e.g. threadIdx.x), so that
//    lane == tid % 32 and warp == tid / 32;
//  * blockDim.x is a multiple of 32 and at most 32 * 32 = 1024.
//
// Threads with tid >= len contribute 0 regardless of the `val` they pass.
// The complete block sum is returned to the thread with tid == 0; every other
// thread receives a partial sum.
//
// NOTE(zcd): The warp size should be taken from the parameters of the GPU
// rather than hard-coded as 32. To make reduceSum efficient it uses warp-level
// parallelism and assumes a warp size of 32, which may differ between GPUs,
// although most cards use 32.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
  // Not named `warpSize`, so CUDA's built-in variable is not shadowed.
  constexpr int kWarpSize = 32;
  // Every lane of a warp executes the shuffles below, and __shfl_down_sync
  // requires the mask to name both the calling lane and any lane it reads
  // from, so the full mask is used. Lanes outside [0, len) are neutralized by
  // zeroing their value instead of being dropped from the mask.
  constexpr unsigned kFullMask = 0xFFFFFFFFu;
  __shared__ T shm[kWarpSize];

  if (tid >= len) val = static_cast<T>(0);

  // Stage 1: tree reduction inside each warp; lane 0 ends with the warp sum.
#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    val += platform::CudaShuffleDownSync(kFullMask, val, offset);

  // Clear all per-warp slots so slots for warps absent from the block are 0.
  if (tid < kWarpSize) shm[tid] = static_cast<T>(0);
  __syncthreads();

  if (tid % kWarpSize == 0) {
    shm[tid / kWarpSize] = val;
  }
  __syncthreads();

  // Stage 2: the first warp reduces the per-warp partial sums.
  if (tid < kWarpSize) {
    val = shm[tid];
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      val += platform::CudaShuffleDownSync(kFullMask, val, offset);
  }
  return val;
}

}  // namespace platform
}  // namespace paddle