// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
namespace paddle {
namespace operators {

#define MAX_INPUT_NUM 3  // the max number of inputs (ET) for BroadcastConfig

namespace kps = paddle::operators::kernel_primitives;

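// DimensionsTransform normalizes the input shapes against the output shape
// before the broadcast kernel is launched: it pads every input to the output
// rank (InputDimensionsExtend) and then merges adjacent dimensions that can
// be treated as one (MergeDimensions). Illustrative example (hypothetical
// shapes, ignoring the internal dimension reversal): out_dims = {2, 3, 4}
// with in_dims = {{2, 3, 4}, {2, 3, 1}} merges the two leading dimensions,
// on which all inputs agree, into out_dims = {6, 4} and
// in_dims = {{6, 4}, {6, 1}}.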
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // To pad the dimensions of the input tensors up to the output rank, using
  // the input variable 'axis' to locate where each input's dimensions align
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of the input tensor is expected to be "
                "equal to the %dth dimension of the output tensor %d or 1, "
                "but received %d.\n",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of the input tensor is expected to be "
                "equal to the %dth dimension of the output tensor %d or 1, "
                "but received %d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

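  // MergeDimensions walks the (already padded and reversed) dimensions and,
  // whenever merge_func reports a run of consecutive mergeable dimensions,
  // collapses that run into one dimension whose extent is the product of the
  // run (VectorReorganise), shrinking dim_size accordingly.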
  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
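  // Constructor pipeline: vectorize the shapes, pad every input to the output
  // rank, merge runs of dimensions on which all inputs agree, temporarily
  // move the input with the fewest elements to slot 0, then merge runs of
  // 1-valued dimensions against it before swapping back.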
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal &= (in_dims[0][i] == in_dims[j][i]);
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal &= (in_dims[j][i] == out[i]);
        }
      }
    };
    // To merge the dimensions of the input tensors wherever consecutive
    // equal dimensions appear.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // To merge the dimensions of the input tensors wherever consecutive
    // 1-valued dimensions appear.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};

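// LoadData reads one tile of up to VecSize elements per thread for a single
// input: a broadcast-aware read (kps::ReadDataBc, which maps flat output
// indices back to input indices through BroadcastConfig) when the input needs
// broadcasting, or a plain vectorized read (kps::ReadData) when the input
// already matches the output shape.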
template <typename T, int VecSize, int ShapeSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst, const T *__restrict__ src, uint32_t block_offset,
    const kps::details::BroadcastConfig<ShapeSize> &config, int numel, int num,
    bool need_broadcast) {
  // numel: total number of output elements
  // num: number of elements to be processed by this block in this pass
  if (need_broadcast) {
    kps::ReadDataBc<T, VecSize, 1, 1, ShapeSize, IsBoundary>(
        dst, src, block_offset, config, numel, 1, 1);
  } else {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
  }
}

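// DealSegment processes one block-sized segment of the output: initialize the
// per-thread registers, load every input (broadcast-aware where needed),
// apply the unary/binary/ternary functor, and write the results back.
// IsBoundary = true guards the partially filled tail segment.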
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor, bool IsBoundary = false>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
    const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
    const framework::Array<kps::details::BroadcastConfig<ShapeSize>,
                           MAX_INPUT_NUM> &configlists,
    int num, Functor func) {
  InT args[ET][VecSize];
  OutT result[VecSize];
  int block_offset = blockIdx.x * blockDim.x * VecSize;
// load
#pragma unroll
  for (int i = 0; i < ET; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    LoadData<InT, VecSize, ShapeSize, IsBoundary>(args[i], in[i], block_offset,
                                                  configlists[i], numel, num,
                                                  use_broadcast[i]);
  }
  // compute
  if (ET == kUnary) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                             func);
  } else if (ET == kBinary) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
                                                              args[1], func);
  } else {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
  // store
  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
                                                  num);
}

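// BroadcastKernel splits the grid into main_tid full blocks, each covering
// blockDim.x * VecSize output elements with no bounds checks, plus one tail
// block that handles the remaining tail_tid elements with IsBoundary = true.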
template <ElementwiseType ET, typename InT, typename OutT, int ShapeSize,
          int VecSize, typename Functor>
__global__ void BroadcastKernel(
    framework::Array<const InT *__restrict__, ET> in, OutT *out,
    framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
    framework::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
        configlists,
    int main_tid, int tail_tid, Functor func) {
  int block_offset = blockIdx.x * blockDim.x * VecSize;
  // data offset of this block
  if (blockIdx.x < main_tid) {
    int num = blockDim.x * VecSize;  // full segment for this block
    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, false>(
        in, out, use_broadcast, numel, configlists, num, func);
  } else {  // remainder
    int num = tail_tid;
    DealSegment<ET, InT, OutT, ShapeSize, VecSize, Functor, true>(
        in, out, use_broadcast, numel, configlists, num, func);
  }
}

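// LaunchKernel computes the launch configuration and the per-input broadcast
// configs, then launches BroadcastKernel. Worked example of the arithmetic
// below (illustrative numbers only): with numel = 10000, VecSize = 4, and
// threads = 256, blocks = ceil(ceil(10000 / 4) / 256) = 10,
// main_tid = 10000 / (4 * 256) = 9 full blocks, and
// tail_tid = 10000 % (4 * 256) = 784 leftover elements.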
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          int Size, typename Functor>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
                  const std::vector<const framework::Tensor *> &ins,
                  framework::Tensor *out, Functor func,
                  DimensionsTransform merge_dims) {
  int numel = out->numel();
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;

  int main_tid = numel / (VecSize * threads);
  int tail_tid = numel % (VecSize * threads);
  auto stream = ctx.stream();
  OutT *out_data = out->data<OutT>();

  framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
      configlists;
  framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
  framework::Array<const InT *__restrict__, ET> ins_data;

  for (int i = 0; i < ET; i++) {
    use_broadcast[i] = (ins[i]->numel() != numel);
    ins_data[i] = ins[i]->data<InT>();
    if (use_broadcast[i]) {
      // Get the broadcast config. Note the dims passed in are reversed:
      // if the data shape is [m, n], data_dim should be set to {n, m};
      // e.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.
      configlists[i] = kps::details::BroadcastConfig<Size>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
  }

  BroadcastKernel<ET, InT, OutT, Size, VecSize,
                  Functor><<<blocks, threads, 0, stream>>>(
      ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
      func);
}

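// The merged rank is only known at runtime, but BroadcastConfig takes it as a
// template argument, so LaunchBroadcastKernelForDifferentDimSize expands a
// switch over the possible ranks 1..8 via the DIM_SIZE macro below.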
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
#define DIM_SIZE(size)                                                       \
  case size: {                                                               \
    LaunchKernel<InT, OutT, ET, VecSize, size, Functor>(ctx, ins, out, func, \
                                                        merge_dims);         \
  } break;

  switch (merge_dims.dim_size) {
    DIM_SIZE(1);
    DIM_SIZE(2);
    DIM_SIZE(3);
    DIM_SIZE(4);
    DIM_SIZE(5);
    DIM_SIZE(6);
    DIM_SIZE(7);
    DIM_SIZE(8);
  }
#undef DIM_SIZE
}

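// LaunchBroadcastElementwiseCudaKernel selects the vectorization width:
// start from 4, clamp by the alignment-derived width of every input whose
// shape already equals the output's, clamp by the output's width, and then
// dispatch on the resulting vec_size of 4, 2, or 1.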
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
                    platform::errors::InvalidArgument(
                        "Currently, only Support binary calculation, "
                        "but received %d input tensors.\n",
                        static_cast<int>(ET)));
  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

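// LaunchElementwiseCudaKernel is the entry point: it takes the fast
// same-dims path when every input shape equals ins[0]'s, and otherwise infers
// a default axis from the rank difference and falls back to the broadcast
// path. A minimal usage sketch (hypothetical tensors x, y, z and binary
// functor func; not taken from a real op):
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       dev_ctx, ins, &outs, /*axis=*/-1, func);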
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                       func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
}

#undef MAX_INPUT_NUM

}  // namespace operators
}  // namespace paddle