// elementwise_op_impl.cu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <limits>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"

// Starting (and ceiling) threads-per-block value for the same-dims
// elementwise launch. GetThreadsConfig begins from this value and may pick a
// different power-of-two block size, never below 64. HIP builds use a
// smaller value than CUDA builds.
#if defined(__HIPCC__)
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif
namespace paddle {
namespace operators {

// Short alias for the kernel-primitive library used by the device code below.
namespace kps = paddle::operators::kernel_primitives;

// Expected number of inputs of an elementwise op.
// LaunchSameDimsElementwiseCudaKernel uses this value as the functor's arity
// when the functor receives its inputs through a pointer argument.
// kAny (-1) does not name a fixed count (presumably "any number of inputs";
// confirm before relying on it).
enum ElementwiseType {
  kUnary = 1,
  kBinary = 2,
  kTernary = 3,
  kAny = -1,
};
/*
 * Chooses the number of threads per block for an elementwise launch.
 *
 * According to NVIDIA, blocks of 64/128/256/512 threads tend to perform
 * best, and the number of blocks should be larger than the number of SMs
 * (at least 2x~4x). The SM count is therefore taken into account:
 *   - if the work cannot give ~2x SM-count blocks ELEMENTWISE_BLOCK_SIZE
 *     threads each, the per-block share is rounded to a power of two;
 *   - otherwise, if it cannot do so for ~4x SM-count blocks, the 4x share is
 *     rounded to a power of two instead;
 *   - otherwise ELEMENTWISE_BLOCK_SIZE is used.
 *
 * @param ctx       device context; only its SM count is queried.
 * @param numel     total number of elements to process.
 * @param vec_size  number of elements each thread handles.
 * @return threads per block; never less than 64.
 */
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
                            int64_t numel, int vec_size) {
  int threads = ELEMENTWISE_BLOCK_SIZE;
  const int64_t sm_count = ctx.GetSMCount();
  // Kept in 64 bits: numel / vec_size can exceed INT_MAX for large tensors,
  // and narrowing it to int produced a meaningless (even negative) count.
  const int64_t active_threads_num = numel / vec_size;
  const int64_t threads_for_2x_blocks = active_threads_num / (sm_count * 2);
  const int64_t threads_for_4x_blocks = active_threads_num / (sm_count * 4);
  if (threads_for_2x_blocks < ELEMENTWISE_BLOCK_SIZE) {
    // Round to a power of two so that about twice as many blocks as SMs are
    // active. The value is < ELEMENTWISE_BLOCK_SIZE here, so the cast is safe.
    threads =
        platform::RoundToPowerOfTwo(static_cast<int>(threads_for_2x_blocks));
  } else if (threads_for_4x_blocks < ELEMENTWISE_BLOCK_SIZE) {
    // Round to a power of two so that about four times as many blocks as SMs
    // are active. Again the value is < ELEMENTWISE_BLOCK_SIZE.
    threads =
        platform::RoundToPowerOfTwo(static_cast<int>(threads_for_4x_blocks));
  }
  // Use at least 64 threads per block.
  return std::max(64, threads);
}

/*
 * Returns the widest vectorization size usable by every tensor involved:
 * the minimum of platform::GetVectorizedSize over all input data pointers
 * (viewed as InT) and all output data pointers (viewed as OutT), capped at 4.
 * Empty `ins`/`outs` contribute nothing, leaving the cap of 4.
 */
template <typename InT, typename OutT>
int GetVectorizedSizeForTensors(
    const std::vector<const framework::Tensor *> &ins,
    const std::vector<framework::Tensor *> &outs) {
  int widest = 4;
  for (const framework::Tensor *in_tensor : ins) {
    const int in_width = platform::GetVectorizedSize(in_tensor->data<InT>());
    widest = std::min<int>(widest, in_width);
  }
  for (framework::Tensor *out_tensor : outs) {
    const int out_width =
        platform::GetVectorizedSize(out_tensor->data<OutT>());
    widest = std::min<int>(widest, out_width);
  }
  return widest;
}

/*
 * Applies `func` to one vectorized chunk by dispatching to the matching
 * kernel-primitive routine. The routine is selected by Arity and by
 * CallElementwiseAny (whether the functor takes its inputs via a pointer).
 *
 * args[i] holds the VecSize values of input i. The VecSize outputs are
 * written through `result`; nothing is returned.
 *
 * The primary template is only declared. Just the partial specializations
 * below have definitions, so an unsupported combination leaves the call
 * undefined (a build error) instead of silently computing something.
 */
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity,
          bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
  // Returns void: results are delivered through `result`. The previous
  // `OutT` return type was never honoured by any definition.
  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
                                    OutT *result);
};

// Specialization for functors that take their inputs through a pointer
// argument: all Arity input rows are handed to kps::ElementwiseAny at once.
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
  // void, not OutT: the body produces its output through `result` and has no
  // return statement. Flowing off the end of a value-returning function is UB.
  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(result, args,
                                                                  func);
  }
};
// Specialization for unary functors: applies `func` to the single input row
// via kps::ElementwiseUnary.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
  // void, not OutT: the body produces its output through `result` and has no
  // return statement. Flowing off the end of a value-returning function is UB.
  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                             func);
  }
};

// Specialization for binary functors: combines the two input rows via
// kps::ElementwiseBinary.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
  // void, not OutT: the body produces its output through `result` and has no
  // return statement. Flowing off the end of a value-returning function is UB.
  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                              args[1], func);
  }
};

// Specialization for ternary functors: combines the three input rows via
// kps::ElementwiseTernary.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
  // `args` must be `InT (*)[VecSize]`, like the other specializations:
  // DealSegment passes a local `InT[Arity][VecSize]` array, which decays to
  // that type. The previous `InT **` parameter could not bind to it, so
  // every arity-3 instantiation failed to compile.
  // Returns void because the output goes through `result` (there was no
  // return statement, which made the old `OutT` return type UB).
  __device__ inline void operator()(Functor func, InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
};

/*
 * Computes this block's chunk of the output.
 *
 * The chunk starts at element VecSize * blockIdx.x * blockDim.x. For every
 * input, VecSize values are loaded into a local row, `func` is applied
 * through ElementwisePrimitiveCaller, and VecSize results are stored.
 *
 * `num` is the number of elements from the chunk start to the end of the
 * tensor. It is forwarded to kps::ReadData / kps::WriteData together with
 * IsBoundary, which presumably turns on bounds checking against `num` for
 * the tail chunk (confirm in kernel_primitives).
 */
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
          bool IsBoundary>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, Arity> &in, OutT *out,
    int num, Functor func) {
  // Functors taking pointer arguments go through kps::ElementwiseAny.
  constexpr bool kUseElementwiseAny =
      platform::FunctionTraits<Functor>::has_pointer_args;
  const int chunk_begin = VecSize * blockIdx.x * blockDim.x;

  InT in_rows[Arity][VecSize];
  OutT out_row[VecSize];

#pragma unroll
  for (int input_idx = 0; input_idx < Arity; ++input_idx) {
    // Pre-fill with 1 before loading (presumably so lanes that the boundary
    // read skips still hold a harmless operand; TODO confirm).
    kps::Init<InT, VecSize>(in_rows[input_idx], static_cast<InT>(1.0f));
    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(
        in_rows[input_idx], in[input_idx] + chunk_begin, num);
  }

  ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
                             kUseElementwiseAny>()(func, in_rows, out_row);

  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + chunk_begin, out_row,
                                                  num);
}

/*
 * 1D elementwise kernel over `size` elements.
 *
 * Block b handles the VecSize * blockDim.x elements starting at
 * VecSize * b * blockDim.x. A block whose chunk runs past `size` (the tail)
 * calls DealSegment with IsBoundary = true; full chunks use
 * IsBoundary = false.
 *
 * Expects a 1D grid covering all `size` elements. ElementwiseCudaKernel
 * computes that launch configuration.
 */
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
__global__ void ElementVectorizeKernel(
    framework::Array<const InT *__restrict__, Arity> ins, OutT *out, int size,
    Functor func) {
  const int chunk_begin = VecSize * blockIdx.x * blockDim.x;
  // Elements from this block's chunk start to the end of the data.
  const int remaining = size - chunk_begin;
  // `VecSize * blockDim.x` is unsigned, so this comparison is unsigned, as it
  // always has been.
  const bool is_tail = remaining < VecSize * blockDim.x;
  if (!is_tail) {
    DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, remaining,
                                                           func);
    return;
  }
  DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, remaining,
                                                        func);
}

/*
 * Launches ElementVectorizeKernel on ctx's stream.
 *
 * The launch covers ins[0]->numel() elements. It reads the first Arity
 * tensors of `ins` as InT and writes (*outs)[0] as OutT. All inputs are
 * assumed to have ins[0]'s element count (TODO confirm callers guarantee
 * this).
 *
 * The launch is asynchronous; no synchronization happens here.
 * Empty tensors launch nothing.
 * Throws InvalidArgument if numel does not fit the kernel's 32-bit `size`.
 */
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
                           const std::vector<const framework::Tensor *> &ins,
                           std::vector<framework::Tensor *> *outs,
                           Functor func) {
  const int64_t numel = ins[0]->numel();
  // Nothing to compute. Without this guard grid_size would be 0, which is an
  // invalid kernel launch configuration.
  if (numel == 0) {
    return;
  }
  // ElementVectorizeKernel takes `int size`. Refuse inputs that would
  // otherwise be silently truncated.
  PADDLE_ENFORCE_LE(
      numel, static_cast<int64_t>(std::numeric_limits<int>::max()),
      platform::errors::InvalidArgument(
          "The number of elements (%d) exceeds the maximum (%d) supported by "
          "the same-dims elementwise CUDA kernel.",
          numel, std::numeric_limits<int>::max()));

  const int block_size = GetThreadsConfig(ctx, numel, VecSize);
  // Each thread handles VecSize elements. Round up both divisions so the
  // tail is covered.
  const int64_t vector_count = (numel + VecSize - 1) / VecSize;
  const int grid_size =
      static_cast<int>((vector_count + block_size - 1) / block_size);

  auto stream = ctx.stream();
  OutT *out_data = (*outs)[0]->data<OutT>();
  framework::Array<const InT *__restrict__, Arity> ins_data;
  for (int i = 0; i < Arity; ++i) {
    ins_data[i] = ins[i]->data<InT>();
  }

  ElementVectorizeKernel<InT, OutT, Functor, Arity,
                         VecSize><<<grid_size, block_size, 0, stream>>>(
      ins_data, out_data, static_cast<int>(numel), func);
}

/*
 * Entry point for elementwise ops whose tensors, per the function's name,
 * share the same dims. Only ins[0]->numel() sizes the launch.
 *
 * Steps:
 *   1. Derive the functor arity: ET when the functor takes its inputs via a
 *      pointer argument, otherwise the functor's own parameter count.
 *   2. Check that ins.size() equals that arity.
 *   3. Pick the widest vectorization size all tensors allow, and launch the
 *      kernel instantiated for it.
 *
 * Throws InvalidArgument on an input-count mismatch, and Unimplemented if
 * the vectorization size is not 4, 2 or 1.
 */
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  using Traits = platform::FunctionTraits<Functor>;
  constexpr int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
  PADDLE_ENFORCE_EQ(ins.size(), kArity,
                    platform::errors::InvalidArgument(
                        "The number of inputs is expected to be equal to the "
                        "arity of functor. But recieved: the number of inputs "
                        "is %d, the arity of functor is %d.",
                        ins.size(), kArity));

  // Widest vectorization every input and output tensor supports.
  const int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
  if (vec_size == 4) {
    ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>(ctx, ins, outs, func);
  } else if (vec_size == 2) {
    ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>(ctx, ins, outs, func);
  } else if (vec_size == 1) {
    ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>(ctx, ins, outs, func);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported vectorized size: %d !", vec_size));
  }
}

}  // namespace operators
}  // namespace paddle