elementwise_op_impl.cu.h 6.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

16 17
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
18
#include "paddle/fluid/platform/fast_divmod.h"
19 20 21 22 23 24 25

#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
namespace paddle {
namespace operators {

// Arity of an element-wise op. The numeric value is the number of input
// tensors, and kernels use it directly as an array extent / loop bound.
enum ElementwiseType {
  kUnary = 1,   // one input (in0)
  kBinary = 2,  // two inputs (in0, in1)
};

// Returns the widest vector width (4, 2 or 1 elements of T) whose
// CudaAlignedVector<T, N> alignment requirement is met by `pointer`.
// Width 1 is always possible, so this never fails.
template <typename T>
int GetVectorizedSizeImpl(const T *pointer) {
  constexpr int kAlignOfVec4 =
      std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
  constexpr int kAlignOfVec2 =
      std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
  const uint64_t addr = reinterpret_cast<uint64_t>(pointer);
  if (addr % kAlignOfVec4 == 0) return 4;
  if (addr % kAlignOfVec2 == 0) return 2;
  return 1;
}

// Returns the largest vector width (4, 2 or 1) usable by every buffer
// involved in the op. This is the minimum of GetVectorizedSizeImpl over the
// data pointers of all `ins` (viewed as InT) and all `outs` (viewed as OutT).
template <typename InT, typename OutT>
int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
                      const std::vector<framework::Tensor *> &outs) {
  int vec_size = 4;
  for (const framework::Tensor *in : ins) {
    vec_size = std::min<int>(vec_size, GetVectorizedSizeImpl(in->data<InT>()));
  }
  for (framework::Tensor *out : outs) {
    vec_size =
        std::min<int>(vec_size, GetVectorizedSizeImpl(out->data<OutT>()));
  }
  return vec_size;
}

// Device-side accessor over the raw buffers of an element-wise op.
// `in1` is only dereferenced when ET == kBinary, so unary ops may leave it as
// nullptr (the constructor's default).
// The *_vector accessors reinterpret the buffers as arrays of
// CudaAlignedVector<T, VecSize>. VecSize must therefore match the alignment of
// in0/in1/out; the launcher in this file picks it with GetVectorizedSize.
template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
struct ElementwiseDataWrapper {
  OutT *out;
  const InT *in0;
  const InT *in1;

  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
                                    const InT *in1 = nullptr)
      : out(out), in0(in0), in1(in1) {}

  using InVecType = CudaAlignedVector<InT, VecSize>;
  using OutVecType = CudaAlignedVector<OutT, VecSize>;

  // Loads the idx-th VecSize-wide chunk of in0 into args[0] and, for binary
  // ops, the idx-th chunk of in1 into args[1].
  inline __device__ void load_vector(InVecType args[], int idx) {
    const InVecType *x_vec = reinterpret_cast<const InVecType *>(in0);
    args[0] = x_vec[idx];
    if (ET == ElementwiseType::kBinary) {
      const InVecType *y_vec = reinterpret_cast<const InVecType *>(in1);
      args[1] = y_vec[idx];
    }
  }

  // Loads element idx of in0 into args[0] and, for binary ops, element idx
  // of in1 into args[1].
  inline __device__ void load_scalar(InT args[], int idx) {
    args[0] = in0[idx];
    if (ET == ElementwiseType::kBinary) {
      args[1] = in1[idx];
    }
  }

  // Writes `res` as the idx-th VecSize-wide chunk of out.
  inline __device__ void store_vector(OutVecType res, int idx) {
    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out);
    out_vec[idx] = res;
  }

  // Writes `res` to element idx of out.
  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
};

// Processes one full VecSize-wide chunk:
//   1. chunk `tid` of each input is loaded as one aligned vector;
//   2. `func` is applied lane by lane, receiving an array of ET scalars
//      (one per input);
//   3. the VecSize results are stored to chunk `tid` of the output as one
//      vector.
// The caller must ensure the whole chunk lies inside the tensor.
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__device__ void VectorizedKernelImpl(
    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
    int tid) {
  using InVecType = CudaAlignedVector<InT, VecSize>;
  using OutVecType = CudaAlignedVector<OutT, VecSize>;
  InVecType ins_vec[ET];
  OutVecType out_vec;

  // Scalar views of the loaded/stored vectors so individual lanes can be
  // indexed.
  InT *ins_ptr[ET];
#pragma unroll
  for (int i = 0; i < ET; ++i) {
    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
  }
  OutT *out_ptr = reinterpret_cast<OutT *>(&out_vec);

  // load
  data.load_vector(ins_vec, tid);

  // compute
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    InT ins[ET];
#pragma unroll
    for (int j = 0; j < ET; ++j) {
      ins[j] = ins_ptr[j][i];
    }
    out_ptr[i] = func(ins);
  }

  // store
  data.store_vector(out_vec, tid);
}

// Element-by-element path: applies `func` to the `remain` consecutive
// elements starting at index `start`, using scalar loads/stores.
// Does nothing when remain <= 0.
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__device__ void ScalarKernelImpl(
    ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
    int start, int remain) {
  for (int i = 0; i < remain; ++i) {
    const int idx = start + i;
    // load
    InT ins[ET];
    data.load_scalar(ins, idx);
    // compute
    OutT out = func(ins);
    // store
    data.store_scalar(out, idx);
  }
}

// Vectorized element-wise kernel. Thread `tid` (flat 1-D index) owns the
// VecSize elements starting at tid * VecSize:
//   - if that chunk lies fully inside [0, size), it uses one vector
//     load/store per input/output;
//   - the thread whose chunk straddles `size` handles the remaining elements
//     one by one;
//   - threads entirely past the end do nothing (remain clamps to 0).
// Expects a 1-D launch with at least ceil(size / VecSize) threads, and
// in0/in1/out aligned for CudaAlignedVector<..., VecSize>.
// NOTE(review): all indices are `int`, so `size` must fit in an int.
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
          typename Functor>
__global__ void VectorizedKernel(const InT *__restrict__ in0,
                                 const InT *__restrict__ in1, OutT *out,
                                 int size, Functor func) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int remain = size - VecSize * tid;
  remain = remain > 0 ? remain : 0;
  auto data = ElementwiseDataWrapper<ET, VecSize, InT, OutT>(out, in0, in1);
  if (remain >= VecSize) {
    VectorizedKernelImpl(data, func, tid);
  } else {
    ScalarKernelImpl(data, func, tid * VecSize, remain);
  }
}

// Non-vectorized element-wise kernel: each thread with flat 1-D index
// tid < size processes exactly element tid; other threads do nothing.
// Expects a 1-D launch with at least `size` threads in total.
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
__global__ void ScalarKernel(const InT *__restrict__ in0,
                             const InT *__restrict__ in1, OutT *out, int size,
                             Functor func) {
  auto data = ElementwiseDataWrapper<ET, 1, InT, OutT>(out, in0, in1);
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // 1 element for in-range threads, 0 (no-op) for the tail of the grid.
  int remain = tid < size ? 1 : 0;
  ScalarKernelImpl(data, func, tid, remain);
}

// Launches an element-wise kernel on ctx's stream.
// - Input: ins[0] (and ins[1] when ET == kBinary) are read as InT.
// - Output: (*outs)[0] is written as OutT.
// - Work size: taken from ins[0]->numel(). The other tensors are assumed to
//   hold at least that many elements ("SameDims" — confirm at call sites).
// - Vector width: the largest of 4/2/1 that every input and output pointer
//   is aligned for, so any misaligned buffer drops the whole launch to a
//   narrower width.
// The launch is asynchronous; this function neither synchronizes nor checks
// the launch status.
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  auto size = ins[0]->numel();
  // Nothing to do for empty tensors. Without this check grid_size below would
  // be 0, and a kernel launch with a zero-sized grid is an invalid
  // configuration.
  if (size == 0) {
    return;
  }

  // calculate the max vec_size for all ins and outs
  int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
  int block_size = ELEMENTWISE_BLOCK_SIZE;
  // Each thread covers vec_size elements: ceil(size / vec_size) threads,
  // rounded up to whole blocks.
  int grid_size =
      ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
  const InT *in0 = ins[0]->data<InT>();
  const InT *in1 =
      (ET == ElementwiseType::kBinary) ? ins[1]->data<InT>() : nullptr;
  OutT *out = (*outs)[0]->data<OutT>();

  // cuda kernel
  // NOTE(review): `size` is passed to the kernels' `int` parameter; tensors
  // with more than INT_MAX elements would be truncated.
  auto stream = ctx.stream();

  switch (vec_size) {
    case 4:
      VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
          in0, in1, out, size, func);
      break;
    case 2:
      VectorizedKernel<ET, 2><<<grid_size, block_size, 0, stream>>>(
          in0, in1, out, size, func);
      break;
    case 1:
      ScalarKernel<ET><<<grid_size, block_size, 0, stream>>>(in0, in1, out,
                                                             size, func);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
  }
}

}  // namespace operators
}  // namespace paddle