/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <climits>
#include <cstdint>
#include <type_traits>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif

namespace paddle {
namespace operators {

enum ElementwiseType { kUnary = 1, kBinary = 2 };

/*
* According to NVIDIA, CUDA performs better when the number of threads per
* block is 64/128/256/512, and the number of blocks should be at least
* 2x~4x the number of SMs. Hence, the SM count is taken into account in
* this function to determine the right number of threads per block.
*/
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
                            int64_t numel, int vec_size) {
  int threads = ELEMENTWISE_BLOCK_SIZE;
  int sm_count = ctx.GetSMCount();
  int active_threads_num = numel / vec_size;
  if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
    // Round the thread count up to a power of 2, so that the number of
    // active blocks is about twice the SM count, for better performance.
    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
  } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
    // Round the thread count up to a power of 2, so that the number of
    // active blocks is about four times the SM count, for better performance.
    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
  }
  // Number of threads per block shall be no less than 64.
  return std::max(64, threads);
}
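
/*
* An illustrative walk-through of the heuristic above (the numbers are
* assumptions, not taken from any particular device): with numel = 2^20,
* vec_size = 4 and sm_count = 80, active_threads_num is 262144;
* 262144 / 160 = 1638 and 262144 / 320 = 819 are both at least
* ELEMENTWISE_BLOCK_SIZE, so the block size stays at ELEMENTWISE_BLOCK_SIZE.
* For a small workload such as numel = 2^14, 4096 / 160 = 25 falls below it,
* so the block size is rounded to a power of 2 and clamped to at least 64.
*/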

/*
* A vectorized load of 2 or 4 elements is only possible when the address of
* the input data is aligned to the corresponding vector type. Moreover, a
* single vectorized load is at most 128 bits. Hence, the valid vectorized
* load width is determined under both constraints.
*/
template <typename T>
int GetVectorizedSizeImpl(const T *pointer) {
  constexpr int max_load_bits = 128;
  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec8 =
      std::alignment_of<CudaAlignedVector<T, 8>>::value;  // NOLINT
  constexpr int vec4 =
      std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
  constexpr int vec2 =
      std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
  if (address % vec8 == 0) {
    /*
    * Currently, vectorized load/store handles at most 4 elements at a time.
    * If performance tests show that handling 8 elements at a time is faster,
    * the return statement below can be changed to
    * " return std::min(8, valid_vec_size); ".
    */
    return std::min(4, valid_vec_size);
  } else if (address % vec4 == 0) {
    return std::min(4, valid_vec_size);
  } else if (address % vec2 == 0) {
    return std::min(2, valid_vec_size);
  } else {
    return 1;
  }
}
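
/*
* An illustrative example, assuming CudaAlignedVector<T, N> is aligned to
* sizeof(T) * N bytes: for float, valid_vec_size = 128 / CHAR_BIT / 4 = 4,
* so a 16-byte (or better) aligned pointer yields a width of 4, an
* 8-byte-aligned pointer yields 2, and any other alignment falls back to
* scalar access (1). For double, valid_vec_size = 2, so the returned width
* never exceeds 2.
*/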

template <typename InT, typename OutT>
int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
                      const std::vector<framework::Tensor *> &outs) {
  int vec_size = 4;
  for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
    vec_size =
        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<InT>()));
  }
  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
    vec_size =
        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<OutT>()));
  }
  return vec_size;
}
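
/*
* Note that this is the minimum over every input and output. As a
* hypothetical example, if all float tensors but one are 16-byte aligned and
* the remaining one is only 4-byte aligned, the function returns 1 and the
* launch below falls back to the scalar kernel.
*/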

template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
struct ElementwiseDataWrapper {
  OutT *out;
  const InT *in0;
  const InT *in1;
  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
                                    const InT *in1 = nullptr)
      : out(out), in0(in0), in1(in1) {}

  using InVecType = CudaAlignedVector<InT, VecSize>;
  using OutVecType = CudaAlignedVector<OutT, VecSize>;

  inline __device__ void load_vector(InVecType args[], int idx) {
    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
    args[0] = x_vec[idx];
    if (ET == ElementwiseType::kBinary) {
      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
      args[1] = y_vec[idx];
    }
  }

  inline __device__ void load_scalar(InT args[], int idx) {
    args[0] = in0[idx];
    if (ET == ElementwiseType::kBinary) {
      args[1] = in1[idx];
    }
  }

  inline __device__ void store_vector(OutVecType res, int idx) {
    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
    out_vec[idx] = res;
  }

  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
};
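
/*
* Indexing convention (illustrative): for a vector index idx, load_vector and
* store_vector touch the scalar elements [idx * VecSize, (idx + 1) * VecSize),
* while load_scalar and store_scalar take a plain element index. With
* VecSize = 4 and float data, thread tid therefore moves elements
* in0[4 * tid] .. in0[4 * tid + 3] in a single 128-bit transaction.
*/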

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__device__ inline void VectorizedKernelImpl(
    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
    int tid) {
  using InVecType = CudaAlignedVector<InT, VecSize>;
  using OutVecType = CudaAlignedVector<OutT, VecSize>;
  InVecType ins_vec[ET];
  OutVecType out_vec;
  InT *ins_ptr[ET];
  InT ins[ET];
#pragma unroll
  for (int i = 0; i < ET; ++i) {
    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
  }
  // load
  data.load_vector(ins_vec, tid);

// compute
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
#pragma unroll
    for (int j = 0; j < ET; ++j) {
      ins[j] = ins_ptr[j][i];
    }
    out_vec.val[i] = func(ins);
  }
  // store
  data.store_vector(out_vec, tid);
}

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__device__ inline void ScalarKernelImpl(
    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
    int start, int remain) {
  InT ins[ET];
  OutT out;

  for (int i = 0; i < remain; ++i) {
    int idx = start + i;
    // load
    data.load_scalar(ins, idx);
    // compute
    out = func(ins);
    // store
    data.store_scalar(out, idx);
  }
}

template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__global__ void VectorizedKernel(const InT *__restrict__ in0,
                                 const InT *__restrict__ in1, OutT *out,
                                 int size, Functor func) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int remain = size - VecSize * tid;
  remain = remain > 0 ? remain : 0;
  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
  if (remain >= VecSize) {
    VectorizedKernelImpl(data, func, tid);
  } else {
    ScalarKernelImpl(data, func, tid * VecSize, remain);
  }
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
__global__ void ScalarKernel(const InT *__restrict__ in0,
                             const InT *__restrict__ in1, OutT *out, int size,
                             Functor func) {
  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int remain = tid < size ? 1 : 0;
  ScalarKernelImpl(data, func, tid, remain);
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  // Calculate the maximum vectorization width usable by all ins and outs.
  auto size = ins[0]->numel();
  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
  int block_size = GetThreadsConfig(ctx, size, vec_size);
  int grid_size =
      ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
  const InT *in0 = ins[0]->data<InT>();
  const InT *in1 =
      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
  OutT *out = (*outs)[0]->data<OutT>();
  // cuda kernel
  auto stream = ctx.stream();

  switch (vec_size) {
    case 4:
      VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
          in0, in1, out, size, func);
      break;
    case 2:
      VectorizedKernel<ET, 2><<<grid_size, block_size, 0, stream>>>(
          in0, in1, out, size, func);
      break;
    case 1:
      ScalarKernel<ET><<<grid_size, block_size, 0, stream>>>(in0, in1, out,
                                                             size, func);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
  }
}
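
/*
* A minimal usage sketch (not part of this header; FakeAddFunctor and the
* tensors x, y, z are hypothetical, and z is assumed to be already allocated
* with the same shape): launch an elementwise add over two float tensors.
*
*   template <typename T>
*   struct FakeAddFunctor {
*     inline __device__ T operator()(const T *args) const {
*       return args[0] + args[1];
*     }
*   };
*
*   std::vector<const framework::Tensor *> ins = {&x, &y};
*   std::vector<framework::Tensor *> outs = {&z};
*   LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, float,
*                                       float>(dev_ctx, ins, &outs,
*                                              FakeAddFunctor<float>());
*
* The functor is invoked as func(args) with args being an InT[ET] array, so a
* binary functor reads args[0] and args[1].
*/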

}  // namespace operators
}  // namespace paddle