/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif

namespace paddle {
namespace operators {

// Number of input tensors an elementwise functor consumes per output element.
// The enumerator value is used directly as an array extent (e.g. `InT ins[ET]`).
enum ElementwiseType {
  kUnary = 1,
  kBinary = 2,
  kTernary = 3,
};

/*
 * Chooses the number of threads per block for a same-dims elementwise launch.
 *
 * According to NVIDIA, CUDA tends to perform better when the number of
 * threads per block is 64/128/256/512, and the number of blocks should be
 * greater (at least 2x~4x) than the number of SMs. Hence the SM count is taken
 * into account here when picking the block size.
 *
 * @param ctx       device context; only its SM count is read.
 * @param numel     total number of elements to process.
 * @param vec_size  number of elements each thread handles.
 * @return          threads per block; never less than 64.
 */
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
                            int64_t numel, int vec_size) {
  int threads = ELEMENTWISE_BLOCK_SIZE;
  int sm_count = ctx.GetSMCount();
  // Keep the thread count in 64 bits: numel is int64_t, and narrowing it to
  // int here would overflow for very large tensors.
  int64_t active_threads_num = numel / vec_size;
  int64_t threads_for_2x_sm = active_threads_num / (sm_count << 1);
  int64_t threads_for_4x_sm = active_threads_num / (sm_count << 2);
  if (threads_for_2x_sm < ELEMENTWISE_BLOCK_SIZE) {
    // Round the thread count up to a power of two so that the number of
    // active blocks is about twice the SM count. The value is below
    // ELEMENTWISE_BLOCK_SIZE here, so narrowing to int is safe.
    threads = platform::RoundToPowerOfTwo(static_cast<int>(threads_for_2x_sm));
  } else if (threads_for_4x_sm < ELEMENTWISE_BLOCK_SIZE) {
    // Round the thread count up to a power of two so that the number of
    // active blocks is about 4 times the SM count. The value is below
    // ELEMENTWISE_BLOCK_SIZE here, so narrowing to int is safe.
    threads = platform::RoundToPowerOfTwo(static_cast<int>(threads_for_4x_sm));
  }
  // The number of threads per block shall be at least 64.
  return std::max(64, threads);
}

// Returns the vector width usable by every input (read as InT) and every
// output (read as OutT): starts from 4 and narrows it to the smallest value
// platform::GetVectorizedSize reports for any of the data pointers.
template <typename InT, typename OutT>
int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
                           const std::vector<framework::Tensor *> &outs) {
  int widest = 4;
  for (const framework::Tensor *in_tensor : ins) {
    widest = std::min<int>(widest,
                           platform::GetVectorizedSize(in_tensor->data<InT>()));
  }
  for (framework::Tensor *out_tensor : outs) {
    widest = std::min<int>(
        widest, platform::GetVectorizedSize(out_tensor->data<OutT>()));
  }
  return widest;
}

/*
 * Kernel-side view of the tensors of a same-dims elementwise op.
 *
 * Holds the raw data pointers of the first ET inputs and of (*outs)[0], so the
 * whole object can be passed to a kernel by value. Inside the kernel:
 *   - LoadVectorizedData / StoreVectorizedData access the tid-th
 *     AlignedVector<_, VecSize> of each buffer;
 *   - LoadScalarizedData / StoreScalarizedData access the single element at
 *     index `tid + scalar_cal_offset`.
 * NOTE(review): the vectorized accessors reinterpret the data pointers as
 * AlignedVector arrays, so the pointers are assumed to be aligned for VecSize
 * (presumably ensured by the caller via GetVectorizedSizeForIO; confirm).
 */
template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
struct ElementwiseDataWrapper {
  using InVecType = platform::AlignedVector<InT, VecSize>;
  using OutVecType = platform::AlignedVector<OutT, VecSize>;

  // Data pointers of the first ET input tensors.
  const InT *__restrict__ in_data[ET];
  // Data pointer of the output tensor (*outs)[0].
  OutT *out_data;
  // Offset added to `tid` by the scalar accessors only.
  uint32_t scalar_cal_offset;

  // Host-only constructor. It indexes std::vector and calls
  // Tensor::data<T>(), both host-side APIs, so it cannot run in device code.
  // The constructed wrapper is copied into the kernel by value.
  // NOTE(review): assumes ins.size() >= ET and !outs->empty(); not checked.
  ElementwiseDataWrapper(const std::vector<const framework::Tensor *> &ins,
                         std::vector<framework::Tensor *> *outs,
                         uint32_t scalar_cal_offset)
      : scalar_cal_offset(scalar_cal_offset) {
    for (int i = 0; i < ET; ++i) {
      in_data[i] = ins[i]->data<InT>();
    }
    out_data = (*outs)[0]->data<OutT>();
  }

  // Loads the tid-th AlignedVector of every input into vec_args[0..ET).
  inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
#pragma unroll
    for (int i = 0; i < ET; ++i) {
      const InVecType *in_vec_data =
          reinterpret_cast<const InVecType *>(in_data[i]);
      vec_args[i] = in_vec_data[tid];
    }
  }

  // Loads element `tid + scalar_cal_offset` of every input into args[0..ET).
  inline __device__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll
    for (int i = 0; i < ET; ++i) {
      args[i] = in_data[i][tid + scalar_cal_offset];
    }
  }

  // Stores `res` as the tid-th AlignedVector of the output.
  inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out_data);
    out_vec[tid] = res;
  }

  // Stores `res` at output element `tid + scalar_cal_offset`.
  inline __device__ void StoreScalarizedData(OutT res, int tid) {
    out_data[tid + scalar_cal_offset] = res;
  }
};

// Vectorized path for one thread: loads one AlignedVector<InT, VecSize> from
// each of the ET inputs at vector index `tid`, applies `func` lane by lane
// (each call receives the ET values of that lane), and stores the VecSize
// results as one AlignedVector<OutT, VecSize> at the same vector index.
template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
          typename InT, typename OutT, typename Functor>
__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
                                            Functor func, int tid) {
  using InVecType = platform::AlignedVector<InT, VecSize>;
  using OutVecType = platform::AlignedVector<OutT, VecSize>;

  InVecType packed_in[ET];
  OutVecType packed_out;
  InT lane_args[ET];

  data.LoadVectorizedData(packed_in, tid);

#pragma unroll
  for (int lane = 0; lane < VecSize; ++lane) {
#pragma unroll
    for (int arg = 0; arg < ET; ++arg) {
      lane_args[arg] = packed_in[arg].val[lane];
    }
    packed_out.val[lane] = func(lane_args);
  }

  data.StoreVectorizedData(packed_out, tid);
}

// Scalar path for one thread: fetches the ET input values for index `tid`,
// applies `func` to them, and writes the single result back. The wrapper's
// Load/StoreScalarizedData decide which element `tid` actually refers to.
template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
          typename OutT, typename Functor>
__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
                                        int tid) {
  InT args[ET];
  data.LoadScalarizedData(args, tid);

  OutT result;
  result = func(args);

  data.StoreScalarizedData(result, tid);
}

// Elementwise kernel for the vectorized case. Uses a 1-D launch (only the .x
// dimensions are read). A thread with global index `gid`:
//   - if gid < main_tid, processes the gid-th vector of VecSize elements;
//   - if gid < tail_tid, also processes one scalar tail element (the wrapper
//     maps gid to the actual tail element index).
// The two conditions are independent, so one thread may take both paths.
template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
          typename OutT, int VecSize, typename Functor>
__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
                                 int tail_tid, Functor func) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const bool has_vector = gid < main_tid;
  const bool has_tail = gid < tail_tid;

  if (has_vector) {
    VectorizedKernelImpl<ET, VecSize, ElementwiseWrapper, InT, OutT, Functor>(
        data, func, gid);
  }
  if (has_tail) {
    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
                                                                 gid);
  }
}

// Elementwise kernel for the non-vectorized case. Uses a 1-D launch with one
// element per thread; threads whose global index is >= numel do nothing.
template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
          typename OutT, typename Functor>
__global__ void ScalarKernel(ElementwiseWrapper data, int numel, Functor func) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= numel) {
    return;
  }
  ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
                                                               gid);
}

/*
 * Launches a same-dims elementwise kernel on ctx.stream().
 *
 * For every element index k in [0, numel), where numel = ins[0]->numel(),
 * writes (*outs)[0][k] = func({ins[0][k], ..., ins[ET-1][k]}).
 *
 * The vector width is the smallest one usable by all input and output
 * pointers (4, 2 or 1; any other value throws Unimplemented):
 *   - 4 or 2: VectorizedKernel processes numel / vec_size vectors plus a
 *     scalar tail of numel % vec_size elements starting at element vec_len;
 *   - 1: ScalarKernel processes every element individually.
 *
 * NOTE(review): every input and the output are assumed to hold at least
 * ins[0]->numel() elements; confirm that callers guarantee identical dims.
 * NOTE(review): element/vector counts are narrowed to int for the kernels,
 * so tensors beyond the 32-bit range would overflow here.
 */
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  auto numel = ins[0]->numel();
  // Nothing to compute for empty tensors. Without this early return the
  // grid_size below would be 0, and a <<<0, block_size>>> launch is rejected
  // by CUDA as an invalid configuration.
  if (numel == 0) {
    return;
  }

  // calculate the max vec_size for all ins and outs
  int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
  int block_size = GetThreadsConfig(ctx, numel, vec_size);
  // Enough threads for ceil(numel / vec_size) vectors, rounded up to whole
  // blocks; this also covers the scalar tail (tail_tid < vec_size).
  int grid_size =
      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
  // Number of full vec_size-wide vectors handled by the vectorized path.
  int main_tid = numel / vec_size;
  // Leftover elements, handled one per thread by the scalar path.
  int tail_tid = numel % vec_size;
  // Index of the first tail element (offset used by the scalar path).
  uint32_t vec_len = main_tid * vec_size;

  // cuda kernel
  auto stream = ctx.stream();

  switch (vec_size) {
    case 4: {
      auto data_wrapper =
          ElementwiseDataWrapper<ET, 4, InT, OutT>(ins, outs, vec_len);
      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
                       4><<<grid_size, block_size, 0, stream>>>(
          data_wrapper, main_tid, tail_tid, func);
      break;
    }
    case 2: {
      auto data_wrapper =
          ElementwiseDataWrapper<ET, 2, InT, OutT>(ins, outs, vec_len);
      VectorizedKernel<ET, decltype(data_wrapper), InT, OutT,
                       2><<<grid_size, block_size, 0, stream>>>(
          data_wrapper, main_tid, tail_tid, func);
      break;
    }
    case 1: {
      // No vectorized part: the scalar path starts at element 0.
      auto data_wrapper =
          ElementwiseDataWrapper<ET, 1, InT, OutT>(ins, outs, 0);
      ScalarKernel<ET, decltype(data_wrapper), InT,
                   OutT><<<grid_size, block_size, 0, stream>>>(data_wrapper,
                                                               numel, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

}  // namespace operators
}  // namespace paddle