// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cuda_fp16.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#endif

#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/pten/common/float16.h"

namespace pten {
namespace kps {
namespace details {

#ifdef __HIPCC__
constexpr int kReduceMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kReduceMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode };

template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<pten::dtype::float16> {
 public:
  using Type = float;
};

/**
 * @brief Used in BlockYReduce to get the index of reduce_num in shared
 * memory.
 */
__device__ __forceinline__ int SharedMemoryIndex(int index) {
  return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}

template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

/* e.g.
 * |---------block---------|
 * |warp0|warp1|warp2|warp3|
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory
 *   \    \    /     /        ---->3. Load the result above from shared memory
 *        res                         to warp0 and process the second WarpReduce
 */

/**
 * @brief BlockXReduce reduces along blockDim.x.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  __syncthreads();
  using details::kWarpSize;
  __shared__ T shared[2 * kWarpSize];
  int block_dim_x = blockDim.x;
  if (blockDim.x > kWarpSize) {
    block_dim_x = blockDim.x / kWarpSize;
    int lane = threadIdx.x % kWarpSize;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int wid = tid / kWarpSize;
    int bid = threadIdx.y;
    val = WarpReduce(val, reducer);
    if (lane == 0) {
      shared[wid] = val;
    }
    __syncthreads();
    val = shared[bid * block_dim_x + lane];
  }

  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {
    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
    val = reducer(val, temp);
  }
  return val;
}

/**
 * @brief BlockYReduce reduces along blockDim.y.
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  __shared__ T shared_memory[1024];
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
      T temp = shared_memory[SharedMemoryIndex(stride)];
      val = reducer(val, temp);
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
  __syncthreads();
  return shared_memory[threadIdx.x];
}

}  // namespace details

/**
 * @brief Perform a unary calculation according to OpFunc. The shapes of the
 * input and output are the same.
 *
 * @template parameters
 * InT: The data type of in.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() of the following form:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a) const {
 *         return ...;
 *       }
 *     };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * compute: Compute function declared like OpFunc<InT, OutT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseUnary(OutT* out,
                                                 const InT* in,
                                                 OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute(in[idx]));
  }
}
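
// Illustrative usage sketch (an assumption for documentation, not part of
// this header): applying a hypothetical SquareFunctor to the registers held
// by one thread. SquareFunctor, kVecSize and the *_reg buffers are names
// invented for this example only.
//
//   template <typename T>
//   struct SquareFunctor {
//     HOSTDEVICE T operator()(const T& a) const { return a * a; }
//   };
//
//   // inside a kernel, with kVecSize elements already staged in registers:
//   T in_reg[kVecSize];
//   T out_reg[kVecSize];
//   kps::ElementwiseUnary<T, T, kVecSize, 1, 1, SquareFunctor<T>>(
//       out_reg, in_reg, SquareFunctor<T>());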

/**
 * @brief Binary calculation according to OpFunc. The shapes of the inputs
 * and output are the same.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns computed by each thread.
 * NY: The number of data rows computed by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() of the following form:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param:
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * NY.
 * in2: The register pointer of the second input, size is NX * NY.
 * compute: Compute function declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseBinary(OutT* out,
                                                  const InT* in1,
                                                  const InT* in2,
                                                  OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx]));
  }
}
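
// Illustrative usage sketch (an assumption, not part of this header): an
// elementwise addition over per-thread registers. AddFunctor, kVecSize and
// the *_reg buffers are names invented for this example only.
//
//   template <typename T>
//   struct AddFunctor {
//     HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
//   };
//
//   kps::ElementwiseBinary<T, T, kVecSize, 1, 1, AddFunctor<T>>(
//       out_reg, lhs_reg, rhs_reg, AddFunctor<T>());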

/**
 * @brief Ternary calculation according to OpFunc. The shapes of the inputs
 * and output are the same.
 *
 * @template parameters
 * InT: The data type of in1, in2 and in3.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() of the following form:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
 *           const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * NY.
 * in2: The register pointer of the second input, size is NX * NY.
 * in3: The register pointer of the third input, size is NX * NY.
 * compute: Compute function declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseTernary(
    OutT* out, const InT* in1, const InT* in2, const InT* in3, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
    out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
  }
}
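
// Illustrative usage sketch (an assumption, not part of this header): a fused
// multiply-add over per-thread registers. FmaFunctor, kVecSize and the *_reg
// buffers are names invented for this example only.
//
//   template <typename T>
//   struct FmaFunctor {
//     HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
//       return a * b + c;
//     }
//   };
//
//   kps::ElementwiseTernary<T, T, kVecSize, 1, 1, FmaFunctor<T>>(
//       out_reg, a_reg, b_reg, c_reg, FmaFunctor<T>());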

/**
 * @brief Multivariate calculation according to OpFunc. The shapes of the
 * inputs and output are the same.
 *
 * @template parameters
 * InT: The data type of the inputs in ins.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * Arity: The number of inputs in ins.
 * OpFunc: Compute functor which has an operator() of the following form:
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(const InT* args) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * ins: An array of register buffers; ins[i] holds the NX * NY elements of
 * the i-th input.
 * compute: Compute function declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          int Arity,
          class OpFunc>
__device__ __forceinline__ void ElementwiseAny(OutT* out,
                                               InT (*ins)[NX * NY],
                                               OpFunc compute) {
  InT args[Arity];
#pragma unroll
  for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
      args[j] = ins[j][idx];
    }
    out[idx] = static_cast<OutT>(compute(args));
  }
}
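
// Illustrative usage sketch (an assumption, not part of this header): summing
// kArity staged inputs element by element. SumArgsFunctor, kArity, kVecSize
// and out_reg are names invented for this example only.
//
//   template <typename T, int kArity>
//   struct SumArgsFunctor {
//     HOSTDEVICE T operator()(const T* args) const {
//       T sum = args[0];
//       for (int i = 1; i < kArity; ++i) sum += args[i];
//       return sum;
//     }
//   };
//
//   T ins[kArity][kVecSize];  // ins[i] holds the registers of the i-th input
//   kps::ElementwiseAny<T, T, kVecSize, 1, 1, kArity,
//                       SumArgsFunctor<T, kArity>>(
//       out_reg, ins, SumArgsFunctor<T, kArity>());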

/**
 * @brief Binary calculation according to OpFunc, where the shapes of in1 and
 * in2 are different. The shape of in1 is [1, NX], the shape of in2 is
 * [NY, NX], and the output shape is [NY, NX].
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * OpFunc: Compute functor which has an operator() of the following form:
 *     template <typename InT, typename OutT>
 *     struct XxxFunctor {
 *       HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of the first input, size is NX * 1.
 * in2: The register pointer of the second input, size is NX * NY.
 * compute: Compute function declared like OpFunc<InT, OutT>().
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void CycleBinary(OutT* out,
                                            const InT* in1,
                                            const InT* in2,
                                            OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX; idx++) {
#pragma unroll
    for (int idy = 0; idy < NY; idy++) {
      out[idx + idy * NX] =
          static_cast<OutT>(compute(in1[idx], in2[idx + idy * NX]));
    }
  }
}
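
// Illustrative usage sketch (an assumption, not part of this header): adding
// a per-column bias of shape [1, kNx] to a [kNy, kNx] register tile, reusing
// the hypothetical AddFunctor from the sketch above. kNx, kNy and the *_reg
// buffers are names invented for this example only.
//
//   T bias_reg[kNx];        // in1, broadcast along the NY dimension
//   T data_reg[kNy * kNx];  // in2
//   T out_reg[kNy * kNx];
//   kps::CycleBinary<T, T, kNx, kNy, 1, AddFunctor<T>>(
//       out_reg, bias_reg, data_reg, AddFunctor<T>());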

/**
 * @brief Reduce provides collective methods for computing a parallel
 * reduction of items partitioned across a CUDA block and within each thread.
 * When ReduceMode == kLocalMode, each thread reduces along NX; when
 * ReduceMode == kGlobalMode, shared memory is used to reduce between threads.
 *
 * @template parameters
 * T: The type of data.
 * NX: The number of data elements continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread; only NY = 1 is
 * currently supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * ReduceFunctor: Compute functor which has an operator() of the following
 * form:
 *     template <typename InT>
 *     struct ReduceFunctor {
 *       HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *         return ...;
 *       }
 *     };
 * ReduceMode: Reduce mode, can be kLocalMode or kGlobalMode.
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * in: The register pointer of in, the size is NX * NY.
 * reducer: Compute function declared like ReduceFunctor<InT>().
 * reduce_last_dim: whether the last dimension is involved in the reduction.
 */
template <typename T,
          int NX,
          int NY,
          int BlockSize,
          class ReduceFunctor,
          details::ReduceMode Mode>
__device__ __forceinline__ void Reduce(T* out,
                                       const T* in,
                                       ReduceFunctor reducer,
                                       bool reduce_last_dim) {
  int block_index = blockDim.y;

  if (Mode == details::ReduceMode::kGlobalMode) {
    bool block_reduce_y = (!reduce_last_dim) && (block_index > 1);
    // when reduce is not required for the last dim, and reduce num has been
    // split into multiple threads
    if (block_reduce_y) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.y
        out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }

    // when the last dimension needs to be reduced
    if (reduce_last_dim) {
#pragma unroll
      for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
        out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
      }
    }
  } else {  // else  kLocalMode
#pragma unroll
    for (int i = 0; i < NY; ++i) {
#pragma unroll
      for (int j = 0; j < NX; ++j) {
        out[i] = reducer(out[i], in[i * NX + j]);
      }
    }
  }
}
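
// Illustrative usage sketch (an assumption, not part of this header): a
// two-step max reduction in which each thread first reduces its own registers
// (kLocalMode) and the block then combines the per-thread results
// (kGlobalMode). MaxFunctor, kVecSize, in_reg, result and reduce_last_dim are
// names invented for this example only.
//
//   template <typename T>
//   struct MaxFunctor {
//     HOSTDEVICE T operator()(const T& a, const T& b) const {
//       return a > b ? a : b;
//     }
//   };
//
//   T result[1] = {in_reg[0]};
//   kps::Reduce<T, kVecSize, 1, 1, MaxFunctor<T>, kps::details::kLocalMode>(
//       result, in_reg, MaxFunctor<T>(), reduce_last_dim);
//   kps::Reduce<T, 1, 1, 1, MaxFunctor<T>, kps::details::kGlobalMode>(
//       result, result, MaxFunctor<T>(), reduce_last_dim);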

/**
 * @brief Fill out with the value returned by OpFunc; compute() takes no
 * arguments and is invoked once per output element.
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * compute: Compute functor that returns the constant value.
 */
template <typename InT,
          typename OutT,
          int NX,
          int NY,
          int BlockSize,
          class OpFunc>
__device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
#pragma unroll
  for (int idx = 0; idx < NX * NY; idx++) {
    out[idx] = static_cast<OutT>(compute());
  }
}
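
// Illustrative usage sketch (an assumption, not part of this header): filling
// a register buffer with zeros. ZeroFunctor, kVecSize and out_reg are names
// invented for this example only.
//
//   template <typename T>
//   struct ZeroFunctor {
//     HOSTDEVICE T operator()() const { return static_cast<T>(0); }
//   };
//
//   kps::ElementwiseConstant<T, T, kVecSize, 1, 1, ZeroFunctor<T>>(
//       out_reg, ZeroFunctor<T>());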

}  // namespace kps
}  // namespace pten