compute_primitives.h 8.6 KB
Newer Older
F
Feng Xing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

N
niuliling123 已提交
17 18 19 20 21 22 23
#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

24 25
// #include <algorithm>
#include "paddle/fluid/platform/cuda_device_function.h"
N
niuliling123 已提交
26 27
#include "paddle/fluid/platform/float16.h"

F
Feng Xing 已提交
28 29
namespace paddle {
namespace operators {
N
niuliling123 已提交
30 31 32
namespace kernel_primitives {
namespace details {

33 34 35 36 37 38 39 40 41 42
#ifdef __HIPCC__
// Values used when compiling with the HIP compiler.
// kMaxThread sizes the shared-memory scratch array in BlockYReduce;
// kWarpSize is the lane count assumed by WarpReduce / BlockXReduce.
constexpr int kMaxThread = 256;
constexpr int kWarpSize = 64;
#else
// Values used for every other (non-HIP) compiler.
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// Selects the behaviour of Reduce():
//   kGlobalMode - combine values held by different threads of the block
//                 (through BlockXReduce / BlockYReduce);
//   kLocalMode  - each thread folds its own NX inputs per row.
enum ReduceMode { kGlobalMode, kLocalMode };

N
niuliling123 已提交
43 44 45 46 47 48 49 50 51 52 53 54
/**
 * @brief Type trait that maps a type T to MPTypeTrait<T>::Type.
 * The primary template maps every T to itself.
 */
template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

/**
 * @brief Specialization for platform::float16: Type is float.
 * NOTE(review): presumably used so float16 data can be computed in float
 * precision - confirm against the callers of this trait.
 */
template <>
class MPTypeTrait<platform::float16> {
 public:
  using Type = float;
};

55 56 57 58 59 60 61
/**
 * @brief Used by BlockYReduce to locate a slot in its shared-memory buffer.
 * The buffer is addressed as rows of blockDim.x elements; the returned slot
 * lies in row (threadIdx.y + index) and column threadIdx.x.
 * @param:
 * index: row offset relative to the calling thread's own row; 0 yields the
 * calling thread's own slot
 */
__device__ __forceinline__ int SharedMemoryIndex(int index) {
  const unsigned int row = threadIdx.y + index;
  const unsigned int slot = row * blockDim.x + threadIdx.x;
  return slot;
}
N
niuliling123 已提交
62

63 64 65 66 67 68 69
/**
 * @brief WarpReduce combines `val` across the lanes of a warp.
 * The shuffle distance starts at kWarpSize / 2 and is halved after every
 * step; at each step the value fetched from the lane `offset` positions away
 * is merged into `val` with `reducer`.
 * NOTE(review): assuming CudaShuffleDownSync behaves like __shfl_down_sync,
 * lane 0 ends up with the reduction over the whole warp - confirm.
 * @param:
 * val: the calling lane's contribution
 * reducer: binary functor, called as reducer(val, other)
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);

  int offset = details::kWarpSize / 2;
  while (offset > 0) {
    T neighbor = paddle::platform::CudaShuffleDownSync(mask, val, offset);
    val = reducer(val, neighbor);
    offset >>= 1;
  }
  return val;
}
N
niuliling123 已提交
73

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */

/**
 * @brief BlockXReduce reduce along blockDim.x
 *
 * When blockDim.x is larger than kWarpSize, every warp is first reduced with
 * WarpReduce, lane 0 of each warp stores its result in shared memory, and
 * each thread then reloads one of the per-warp results belonging to its own
 * threadIdx.y row. In every case a final series of shuffle steps with
 * distances 1, 2, 4, ... (below the remaining width) merges what is left.
 *
 * Contains __syncthreads(), so all threads of the block must call it.
 * NOTE(review): the indexing appears to assume blockDim.x is a power of two
 * (and a multiple of kWarpSize when larger than it) - confirm with the
 * launch configuration.
 * @param:
 * val: the calling thread's contribution
 * reducer: binary functor, called as reducer(val, other)
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  // Entry barrier: the static shared buffer below is reused by every call.
  __syncthreads();
  using details::kWarpSize;
  __shared__ T warp_partials[2 * kWarpSize];

  // Number of values that still have to be merged by the final shuffle loop.
  int span = blockDim.x;
  if (blockDim.x > kWarpSize) {
    span = blockDim.x / kWarpSize;  // one partial result per warp
    const int lane_id = threadIdx.x % kWarpSize;
    const int linear_tid = threadIdx.y * blockDim.x + threadIdx.x;
    const int warp_id = linear_tid / kWarpSize;
    const int row = threadIdx.y;

    val = WarpReduce(val, reducer);
    if (lane_id == 0) {
      warp_partials[warp_id] = val;
    }
    __syncthreads();
    // Reload: lane k of a row picks up the k-th warp result of that row.
    val = warp_partials[row * span + lane_id];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int offset = 1; offset < span; offset *= 2) {
    T neighbor = paddle::platform::CudaShuffleDownSync(mask, val, offset);
    val = reducer(val, neighbor);
  }
  return val;
}

/**
 * @brief BlockYReduce reduce along blockDim.y
 *
 * Every thread publishes `val` in a shared-memory slot addressed by
 * SharedMemoryIndex(0). The distance `stride` starts at blockDim.y / 2 and is
 * halved each step; threads with threadIdx.y < stride merge the value of the
 * thread `stride` rows further along into their own. After the loop the
 * threads with threadIdx.y == 0 hold the result for their column; the value
 * returned by other threads is only a partial reduction.
 *
 * Contains __syncthreads(), so all threads of the block must call it
 * together. The shared buffer holds details::kMaxThread elements, so
 * blockDim.x * blockDim.y must not exceed that. Because the distance is
 * halved with integer division, a blockDim.y that is not a power of two
 * leaves some rows out of the result.
 * @param:
 * val: the calling thread's contribution
 * reducer: binary functor, called as reducer(val, other)
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[details::kMaxThread];
  // The buffer is static and therefore shared by consecutive calls. Without
  // this barrier a thread that already returned from the previous call could
  // overwrite its slot while another thread is still reading it in that
  // call's last iteration.
  __syncthreads();
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
      // Only threads that changed `val` write back. In one iteration reads
      // touch rows [stride, 2 * stride) and writes touch rows [0, stride), so
      // no slot is read and written concurrently.
      shared_memory[SharedMemoryIndex(0)] = val;
    }
  }
  return val;
}

}  // namespace details
N
niuliling123 已提交
135 136 137 138

/*************************** Compute Function****************************/

/**
 * @brief binary function: applies `compute` to NX * NY pairs of elements,
 * in1 and in2 have the same shape
 * @param:
 * T: data type of in1, in2
 * OutT: data type of out
 * NX: the cols of in1, in2
 * NY: the rows of in1, in2
 * BlockSize: the config of this device (not referenced in the body)
 * OpFunc: compute functor eg: in1 + in2, in1 - in2. It is called with a
 * pointer to a two-element array {in1[i], in2[i]} and its result is cast to
 * OutT
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
                                                  const T* in2,
                                                  OpFunc compute) {
  constexpr int kNumElements = NX * NY;
  // Scratch pair handed to the functor for every element.
  T operands[2];
#pragma unroll
  for (int i = 0; i < kNumElements; ++i) {
    operands[0] = in1[i];
    operands[1] = in2[i];
    out[i] = static_cast<OutT>(compute(operands));
  }
}

/**
 * @brief ternary function: applies `compute` to NX * NY triples of elements,
 * in1, in2 and in3 have the same shape
 * @param:
 * T: data type of in1, in2, in3
 * OutT: data type of out
 * NX: the cols of in1, in2, in3
 * NY: the rows of in1, in2, in3
 * BlockSize: the config of this device (not referenced in the body)
 * OpFunc: compute functor eg: out = in1 * in2 + in3. It is called with a
 * pointer to a three-element array {in1[i], in2[i], in3[i]} and its result
 * is cast to OutT
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
                                                   const T* in2, const T* in3,
                                                   OpFunc compute) {
  constexpr int kNumElements = NX * NY;
  // Scratch triple handed to the functor for every element.
  T operands[3];
#pragma unroll
  for (int i = 0; i < kNumElements; ++i) {
    operands[0] = in1[i];
    operands[1] = in2[i];
    operands[2] = in3[i];
    out[i] = static_cast<OutT>(compute(operands));
  }
}
N
niuliling123 已提交
186 187

/**
 * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
 * is [NY, NX], out's shape size is [NY, NX]. The single row of in1 is
 * combined with every row of in2.
 * @param:
 * T: data type of in1, in2
 * OutT: data type of out
 * NX: the cols of in1, in2
 * NY: the rows of in2
 * BlockSize: the config of this device (not referenced in the body)
 * OpFunc: compute functor eg: in1 + in2, in1 - in2. It is called with the
 * two operands as separate arguments and its result is cast to OutT
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
                                            const T* in2, OpFunc compute) {
#pragma unroll
  for (int col = 0; col < NX; ++col) {
#pragma unroll
    for (int row = 0; row < NY; ++row) {
      // Row-major position inside the [NY, NX] tile.
      const int pos = col + row * NX;
      out[pos] = static_cast<OutT>(compute(in1[col], in2[pos]));
    }
  }
}
N
niuliling123 已提交
211 212

/**
 * @brief unary function: applies `compute` to each of the NX * NY elements
 * @param:
 * T: data type of in
 * OutT: data type of out
 * NX: the cols of in
 * NY: the rows of in
 * BlockSize: the config of this device (not referenced in the body)
 * OpFunc: compute functor eg: relu, exp. It is called with a pointer to the
 * current input element and its result is cast to OutT
 */
template <typename T, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
                                                 OpFunc compute) {
  constexpr int kNumElements = NX * NY;
#pragma unroll
  for (int i = 0; i < kNumElements; ++i) {
    out[i] = static_cast<OutT>(compute(&in[i]));
  }
}

232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
/**
 * @brief reduce function.
 * kLocalMode: `in` is read as NY rows of NX elements (row-major); every row
 * is folded into out[row], starting from the value already stored in
 * out[row], so `out` has NY elements and must be initialised by the caller.
 * kGlobalMode: `in` is not read; each of the NY * NX values in `out` is
 * reduced in place across the threads of the block. If reduce_last_dim is
 * true BlockXReduce is called; otherwise, when blockDim.y > 1, BlockYReduce
 * is called; otherwise `out` is left unchanged.
 * @typename:
 * T: data type of in
 * NX: the cols of in
 * NY: the rows of in
 * BlockSize: the config of this device (not referenced in the body)
 * OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
 * Mode: details::kGlobalMode or details::kLocalMode
 * @param:
 * reducer: reduce functor, eg: CustomSum<T>()
 * reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
 * true
 */
template <typename T, int NX, int NY, int BlockSize, class OpFunc,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
                                       bool reduce_last_dim) {
  if (Mode != details::ReduceMode::kGlobalMode) {
    // kLocalMode: thread-private fold of every row of `in`.
#pragma unroll
    for (int row = 0; row < NY; ++row) {
#pragma unroll
      for (int col = 0; col < NX; ++col) {
        out[row] = reducer(out[row], in[row * NX + col]);
      }
    }
    return;
  }

  // kGlobalMode: cooperate with the other threads of the block.
  constexpr int kNumElements = NY * NX;
  if (reduce_last_dim) {
    // the last dimension is being reduced: combine along blockDim.x
#pragma unroll
    for (int k = 0; k < kNumElements; ++k) {
      out[k] = details::BlockXReduce<T, OpFunc>(out[k], reducer);
    }
  } else if (static_cast<int>(blockDim.y) > 1) {
    // the last dimension is kept and the reduce num has been split over
    // several threads: combine along blockDim.y
#pragma unroll
    for (int k = 0; k < kNumElements; ++k) {
      out[k] = details::BlockYReduce<T, OpFunc>(out[k], reducer);
    }
  }
}

N
niuliling123 已提交
285 286 287
}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle