// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {

#ifdef __HIPCC__
constexpr int kReduceMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kReduceMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode };

template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<platform::float16> {
 public:
  using Type = float;
};
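
// Illustrative note (a sketch added for clarity, not part of the original
// header): MPTypeTrait selects the type used for intermediate accumulation,
// promoting float16 to float and leaving other types unchanged, e.g.
//
//   using MPType = MPTypeTrait<platform::float16>::Type;  // float
//   using SameType = MPTypeTrait<double>::Type;           // double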

/**
 * @brief Used in BlockYReduce: computes the shared memory index of the value
 * held by the thread at (threadIdx.y + index, threadIdx.x).
 */
__device__ __forceinline__ int SharedMemoryIndex(int index) {
  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}

template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */

/**
 * @brief BlockXReduce reduces along blockDim.x.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  __syncthreads();
  using details::kWarpSize;
  __shared__ T shared[2 * kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    block_dim_x = blockDim.x / kWarpSize;
    int lane = threadIdx.x % kWarpSize;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int wid = tid / kWarpSize;
    int bid = threadIdx.y;
    val = WarpReduce(val, reducer);
    if (lane == 0) {
      shared[wid] = val;
    }
    __syncthreads();
    val = shared[bid * block_dim_x + lane];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

/**
 * @brief BlockYReduce reduces along blockDim.y.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[details::kReduceMaxThread];
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
  return val;
}

}  // namespace details

/**
 * @brief Perform unary calculation according to OpFunc. Shapes of the input
 * and output are the same.
 *
 * @template parameters
 * InT: The data type of in.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as follows:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
                                                 OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute(in[idx]));
  }
}
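
// Illustrative usage sketch (not part of the original header): ExpFunctor and
// the register buffers below are hypothetical names used only to show how
// ElementwiseUnary is typically instantiated inside a kernel, assuming
// NX = 4, NY = 1 and BlockSize = 1.
//
//   template <typename InT, typename OutT>
//   struct ExpFunctor {
//     HOSTDEVICE OutT operator()(const InT& a) const {
//       return static_cast<OutT>(exp(static_cast<double>(a)));
//     }
//   };
//
//   float in[4], out[4];  // per-thread register buffers
//   ElementwiseUnary<float, float, 4, 1, 1, ExpFunctor<float, float>>(
//       out, in, ExpFunctor<float, float>());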

/**
 * @brief Binary calculation according to OpFunc. Shapes of the inputs and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns computed by each thread.
 * NY: The number of data rows computed by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as follows:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * NY.
 * in2: The register pointer of the second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}
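
// Illustrative usage sketch (hypothetical AddFunctor, added for clarity):
// an element-wise addition of two per-thread register buffers, assuming
// NX = 4, NY = 1 and BlockSize = 1.
//
//   template <typename InT>
//   struct AddFunctor {
//     HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
//       return a + b;
//     }
//   };
//
//   float in1[4], in2[4], out[4];
//   ElementwiseBinary<float, float, 4, 1, 1, AddFunctor<float>>(
//       out, in1, in2, AddFunctor<float>());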

/**
 * @brief Ternary calculation according to OpFunc. Shapes of the inputs and
 * output are the same.
 *
 * @template parameters
 * InT: The data type of in1, in2 and in3.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as follows:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
 *           const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * NY.
 * in2: The register pointer of the second input, size is NX * NY.
 * in3: The register pointer of the third input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
                                                   const InT* in2,
                                                   const InT* in3,
                                                   OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
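
// Illustrative usage sketch (hypothetical FmaFunctor, added for clarity):
// a fused multiply-add over three per-thread register buffers, assuming
// NX = 4, NY = 1 and BlockSize = 1.
//
//   template <typename InT>
//   struct FmaFunctor {
//     HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
//         const {
//       return a * b + c;
//     }
//   };
//
//   float in1[4], in2[4], in3[4], out[4];
//   ElementwiseTernary<float, float, 4, 1, 1, FmaFunctor<float>>(
//       out, in1, in2, in3, FmaFunctor<float>());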

/**
 * @brief Multivariate calculation according to OpFunc. Shapes of the inputs
 * and output are the same.
 *
 * @template parameters
 * InT: The data type of the inputs in ins.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Arity: The number of inputs in ins.
 * OpFunc: Compute functor which has an operator() as follows:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT* args) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * ins: A pointer to an array of Arity input buffers, each of size NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
          class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
                                               OpFunc compute) {
  InT args[Arity];
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
      args[j] = ins[j][idx];
    }
    out[idx] = static_cast<OutT>(compute(args));
  }
}
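
// Illustrative usage sketch (hypothetical SumFunctor, added for clarity):
// summing Arity = 3 inputs per element; ins is declared as an array of Arity
// register buffers so that it decays to the InT (*)[NX * NY] parameter above.
// NX = 4, NY = 1 and BlockSize = 1 are assumed.
//
//   template <typename InT>
//   struct SumFunctor {
//     HOSTDEVICE InT operator()(const InT* args) const {
//       return args[0] + args[1] + args[2];
//     }
//   };
//
//   float ins[3][4];  // Arity = 3 buffers of NX * NY = 4 elements each
//   float out[4];
//   ElementwiseAny<float, float, 4, 1, 1, 3, SumFunctor<float>>(
//       out, ins, SumFunctor<float>());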

/**
 * @brief Binary calculation according to OpFunc. The shapes of in1 and in2 are
 * different: the shape of in1 is [1, NX] while the shape of in2 is [NY, NX],
 * and the output shape is [NY, NX].
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() as follows:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * 1.
 * in2: The register pointer of the second input, size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT, OutT>().
 */
template <typename InT, typename OutT, int NX, int NY, int BlockSize,
          class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out, const InT* in1,
                                            const InT* in2, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX; idx++) {
#pragma unroll
    for (int idy = 0; idy < NY; idy++) {
      out[idx + idy * NX] =
          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
    }
  }
}
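
// Illustrative usage sketch (hypothetical MulFunctor, added for clarity):
// in1 holds NX values that are reused across the NY rows of in2, producing an
// output of NX * NY values. NX = 4, NY = 2 and BlockSize = 1 are assumed.
//
//   template <typename InT, typename OutT>
//   struct MulFunctor {
//     HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
//       return static_cast<OutT>(a * b);
//     }
//   };
//
//   float in1[4], in2[8], out[8];  // in2 and out hold NY * NX = 8 values
//   CycleBinary<float, float, 4, 2, 1, MulFunctor<float, float>>(
//       out, in1, in2, MulFunctor<float, float>());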

/**
 * @brief The Reduce provides collective methods for computing a parallel
 * reduction of items partitioned across a CUDA block and within a thread. When
 * ReduceMode == kLocalMode, each thread reduces along nx. When ReduceMode ==
 * kGlobalMode, shared memory is used to reduce between threads.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * ReduceFunctor: Compute functor which has an operator() as follows:
 *     template <typename InT>
 *     struct ReduceFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 * ReduceMode: Reduce mode, can be kLocalMode or kGlobalMode.
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * reducer: Compute function which was declared like ReduceFunctor<InT>().
 * reduce_last_dim: whether the last dimension is involved in the reduction.
 */
template <typename T, int NX, int NY, int BlockSize, class ReduceFunctor,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out, const T* in,
                                       ReduceFunctor reducer,
                                       bool reduce_last_dim) {
  int block_index = blockDim.y;

  if (Mode == details::ReduceMode::kGlobalMode) {
    bool block_reduce_y = (!reduce_last_dim) && (block_index > 1);
    // when reduce is not required for the last dim, and reduce num has been
    // split into multiple threads
    if (block_reduce_y) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockdim.y
        out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }

    // when last dimension need to be reduced
    if (reduce_last_dim) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
        out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }
  } else {  // else  kLocalMode
#pragma unroll
    for (int i = 0; i < NY; ++i) {
#pragma unroll
      for (int j = 0; j < NX; ++j) {
        out[i] = reducer(out[i], in[i * NX + j]);
      }
    }
  }
}
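
// Illustrative usage sketch (hypothetical MaxFunctor, added for clarity):
// a thread-local (kLocalMode) reduction of NX = 4 values per thread. out[0]
// must be initialized before the call; seeding it with in[0] is safe for max.
// reduce_last_dim is only consulted in kGlobalMode.
//
//   template <typename InT>
//   struct MaxFunctor {
//     HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
//       return (a > b) ? a : b;
//     }
//   };
//
//   float in[4];
//   float out[1] = {in[0]};
//   Reduce<float, 4, 1, 1, MaxFunctor<float>, details::ReduceMode::kLocalMode>(
//       out, in, MaxFunctor<float>(), /*reduce_last_dim=*/true);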

}  // namespace kernel_primitives
}  // namespace operators
}  // namespace paddle