/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"

#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif

namespace paddle {
namespace operators {

namespace kps = paddle::operators::kernel_primitives;

enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

/*
 * According to NVIDIA, CUDA performs better when the number of threads per
 * block is 64/128/256/512, and the number of blocks should be greater (at
 * least 2x~4x) than the number of SMs. Hence, the SM count is taken into
 * account in this function to determine a suitable number of threads per
 * block.
 */
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
                            int64_t numel, int vec_size) {
  int threads = ELEMENTWISE_BLOCK_SIZE;
  int sm_count = ctx.GetSMCount();
  int active_threads_num = numel / vec_size;
  if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
    // Round the thread count up to a power of 2 so that the number of active
    // blocks is about twice the SM count, to acquire better performance.
    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
  } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
    // Round the thread count up to a power of 2 so that the number of active
    // blocks is about 4 times the SM count, to acquire better performance.
    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
  }
  // The number of threads per block shall be at least 64.
  return std::max(64, threads);
}

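// Returns the largest vectorization width (at most 4) whose alignment
// requirement is satisfied by the data pointers of every input and output
// tensor.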
template <typename InT, typename OutT>
int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
                           const std::vector<framework::Tensor *> &outs) {
  int vec_size = 4;
  for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
    vec_size = std::min<int>(vec_size,
                             platform::GetVectorizedSize((*iter)->data<InT>()));
  }
  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
    vec_size = std::min<int>(
        vec_size, platform::GetVectorizedSize((*iter)->data<OutT>()));
  }
  return vec_size;
}

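// Processes one block-sized segment of the flattened tensors: each thread
// loads up to VecSize elements from every input, applies the functor, and
// writes the results back. IsBoundary marks the tail segment, where reads
// and writes are guarded so they do not go past `num` elements.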
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor, bool IsBoundary>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, ET> &in, OutT *out, int num,
    Functor func) {
  int data_offset = VecSize * blockIdx.x * blockDim.x;
  InT args[ET][VecSize];
  OutT result[VecSize];
// load data
#pragma unroll
  for (int i = 0; i < ET; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
                                                  num);
  }

  // compute
  if (ET == kUnary) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                             func);
  } else if (ET == kBinary) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                              args[1], func);
  } else {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }

  // store
  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
                                                  num);
}

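// CUDA kernel: each thread block handles one segment of VecSize * blockDim.x
// elements. The last (possibly partial) segment takes the boundary-checked
// path of DealSegment; all other segments take the unchecked fast path.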
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__global__ void ElementVectorizeKernel(
    framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
    Functor func) {
  int data_offset = VecSize * blockIdx.x * blockDim.x;
  // number of elements left for this block to process
  int num = size - data_offset;
  if (VecSize * blockDim.x > num) {  // remainder segment
    DealSegment<ET, VecSize, InT, OutT, Functor, true>(in, out, num, func);
  } else {  // complete segment
    DealSegment<ET, VecSize, InT, OutT, Functor, false>(in, out, num, func);
  }
}

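// Configures the launch (grid and block sizes) for a given VecSize, gathers
// the raw input/output data pointers, and launches ElementVectorizeKernel on
// the device context's stream.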
template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
          int VecSize>
void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
                           const std::vector<const framework::Tensor *> &ins,
                           std::vector<framework::Tensor *> *outs,
                           Functor func) {
  auto numel = ins[0]->numel();
  int block_size = GetThreadsConfig(ctx, numel, VecSize);
  int grid_size =
      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;

  auto stream = ctx.stream();
  OutT *out = (*outs)[0]->data<OutT>();
  framework::Array<const InT *__restrict__, ET> in;
  for (int i = 0; i < ET; i++) {
    in[i] = ins[i]->data<InT>();
  }
  ElementVectorizeKernel<ET, VecSize, InT, OutT,
                         Functor><<<grid_size, block_size, 0, stream>>>(
      in, out, numel, func);
}

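// Entry point for same-shape elementwise computation: picks the largest
// vectorization size shared by all inputs and outputs and dispatches to the
// matching ElementwiseCudaKernel instantiation.
//
// Minimal usage sketch (CudaAddFunctor below is a hypothetical binary
// functor, not defined in this file):
//
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       dev_ctx, ins, &outs, CudaAddFunctor<T>());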
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  // Calculate the largest vec_size supported by all ins and outs.
  int vec_size = GetVectorizedSizeForIO<InT, OutT>(ins, *outs);
  switch (vec_size) {
    case 4:
      ElementwiseCudaKernel<ET, InT, OutT, Functor, 4>(ctx, ins, outs, func);
      break;
    case 2:
      ElementwiseCudaKernel<ET, InT, OutT, Functor, 2>(ctx, ins, outs, func);
      break;
    case 1:
      ElementwiseCudaKernel<ET, InT, OutT, Functor, 1>(ctx, ins, outs, func);
      break;
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d!", vec_size));
      break;
    }
  }
}

}  // namespace operators
}  // namespace paddle