// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

namespace paddle {
namespace operators {

#define MAX_INPUT_NUM 3  // the maximum number of inputs (ET) for BroadcastConfig

namespace kps = paddle::operators::kernel_primitives;

struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // To pad each input tensor's dimensions to the output's rank, aligning
  // them according to the input variable 'axis'
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "to the %dth dimension of output tensor %d or to be 1, but "
                "received %d.\n",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "to the %dth dimension of output tensor %d or to be 1, but "
                "received %d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal &= (in_dims[0][i] == in_dims[j][i]);
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal &= (in_dims[j][i] == out[i]);
        }
      }
    };
    // Merge consecutive dimensions on which all input tensors are equal.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // Merge consecutive dimensions on which the smallest input has extent 1.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};
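// An illustrative walk-through of the transform above (assumed shapes, not
// from the original source): two inputs in0 = [2, 3, 4, 5] and
// in1 = [1, 1, 4, 5] broadcast to out = [2, 3, 4, 5] with axis = 0.
//   1. InputDimensionsExtend: the ranks already match, so the shapes are only
//      reversed to fastest-varying-first order:
//        in0 = [5, 4, 3, 2], in1 = [5, 4, 1, 1], out = [5, 4, 3, 2]
//   2. merge_sequential_dims folds the leading run where all inputs agree:
//        in0 = [20, 3, 2], in1 = [20, 1, 1], out = [20, 3, 2]
//   3. merge_sequential_one_dims folds the trailing run where the smallest
//      input has extent 1:
//        in0 = [20, 6], in1 = [20, 1], out = [20, 6]
// The broadcast thus collapses from 4 merged dimensions to 2 (dim_size == 2).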

template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst, const T *__restrict__ src, uint32_t block_offset,
    const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num,
    bool need_broadcast) {
  // numel: total number of elements in the output
  // num: the number of elements processed in this call
  if (need_broadcast) {
    kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>(
        dst, src, block_offset, config, numel, 1, 1);
  } else {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
  }
}

template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor, bool IsBoundary = false>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
    const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
    const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
                           MAX_INPUT_NUM> &configlists,
    int num, Functor func) {
  InT args[ET][VecSize];
  OutT result[VecSize];
  int block_offset = blockIdx.x * blockDim.x * VecSize;
// load
#pragma unroll
  for (int i = 0; i < ET; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
                                                  configlists[i], numel, num,
                                                  use_broadcast[i]);
  }
  // compute
  if (ET == kUnary) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                             func);
  } else if (ET == kBinary) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                              args[1], func);
  } else {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
  // store
  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
                                                  num);
}

template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor>
__global__ void BroadcastKernel(
    framework::Array<const InT *__restrict__, ET> in, OutT *out,
    framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
    framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
        configlists,
    int main_tid, int tail_tid, Functor func) {
  // data offset of this block
  int block_offset = blockIdx.x * blockDim.x * VecSize;
  if (blockIdx.x < main_tid) {
    // full block: processes blockDim.x * VecSize elements
    int num = blockDim.x * VecSize;
    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
        in, out, use_broadcast, numel, configlists, num, func);
  } else {  // remainder: the last block processes only tail_tid elements
    int num = tail_tid;
    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
        in, out, use_broadcast, numel, configlists, num, func);
  }
}

template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          int Size, typename Functor>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
                  const std::vector<const framework::Tensor *> &ins,
                  framework::Tensor *out, Functor func,
                  DimensionsTransform merge_dims) {
  int numel = out->numel();
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  int main_tid = numel / (VecSize * threads);
  int tail_tid = numel % (VecSize * threads);
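  // Illustrative arithmetic (assumed values): with numel = 5000, VecSize = 4
  // and threads = 256, each block covers 256 * 4 = 1024 elements, so
  // main_tid = 5000 / 1024 = 4 full blocks, tail_tid = 5000 % 1024 = 904
  // leftover elements, and blocks = 5 (one boundary block for the tail).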
  auto stream = ctx.stream();
  OutT *out_data = out->data<OutT>();

  framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
      configlists;
  framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
  framework::Array<const InT *__restrict__, ET> ins_data;
  for (int i = 0; i < ET; i++) {
    use_broadcast[i] = (ins[i]->numel() != numel);
    ins_data[i] = ins[i]->data<InT>();
    if (use_broadcast[i]) {
      // Get the broadcast config. If the runtime data shape is [m, n],
      // data_dim should be set to {n, m}; e.g. if out's shape is [3, 45, 1],
      // then out_dims = {1, 45, 3}.
      configlists[i] = kps::details::BroadcastConfig<Size>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
  }

  BroadcastKernel<ET, InT, OutT, Size, VecSize,
                  Functor><<<blocks, threads, 0, stream>>>(
      ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
      func);
}

template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
#define DIM_SIZE(size)                                                       \
  case size: {                                                               \
    LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \
                                                        merge_dims);         \
  } break;
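  // Static dispatch on the merged rank: e.g. merge_dims.dim_size == 3 expands
  // to LaunchKernel<InT, OutT, ET, VecSize, 3, Functor>(...).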

  switch (merge_dims.dim_size) {
    DIM_SIZE(1);
    DIM_SIZE(2);
    DIM_SIZE(3);
    DIM_SIZE(4);
    DIM_SIZE(5);
    DIM_SIZE(6);
    DIM_SIZE(7);
    DIM_SIZE(8);
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported merged dimension size: %d !", merge_dims.dim_size));
    }
  }
#undef DIM_SIZE
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
                    platform::errors::InvalidArgument(
                        "Currently, only binary calculation is supported, "
                        "but %d input tensors were received.\n",
                        static_cast<int>(ET)));
  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);
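  // Illustrative example (assumed alignment): for float data whose pointers
  // are 16-byte aligned, GetVectorizedSize typically reports 4, so two
  // same-shaped float inputs and a float output select vec_size = 4,
  // i.e. float4-width loads and stores.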

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                       func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
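    // e.g. (illustrative): for input shapes [2, 3, 4] and [4], axis = -1 is
    // resolved to 3 - 1 = 2, aligning the rank-1 input with the trailing
    // dimension of the rank-3 input.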
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
}
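
// A minimal usage sketch (illustrative only; not part of the original file).
// `AddFunctor` here is a hypothetical functor exposing a device-side
// `T operator()(const T &a, const T &b) const`; x, y and z are assumed
// framework::Tensor objects with broadcast-compatible shapes, z already
// resized and allocated for the result:
//
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, float, float>(
//       dev_ctx, ins, &outs, /*axis=*/-1, AddFunctor<float>());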

#undef MAX_INPUT_NUM

}  // namespace operators
}  // namespace paddle