elementwise_op_broadcast.cu.h 18.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
18 19 20 21 22 23 24 25 26 27 28 29 30

namespace paddle {
namespace operators {

// Rewrites the input/output dimensions of a broadcast elementwise op into a
// compact form for the CUDA kernel:
//   1. inputs whose rank is lower than the output's are padded with 1s
//      (aligned at `axis`) so every input has the output's rank, and all
//      dimension vectors are reversed (innermost dimension first);
//   2. runs of consecutive dimensions whose extent is identical across all
//      inputs are folded into one dimension holding their product;
//   3. runs where the smallest input has extent 1 and every other input
//      matches the output are folded the same way.
// After construction `dim_size`, `out_dims` and every `in_dims[j]` describe
// the merged shape.
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // Compensates for inputs whose rank is lower than the output's by padding
  // them with 1s; the input's first dimension is aligned with output
  // dimension `axis`. Every input dimension must equal the matching output
  // dimension or be 1. Finally every dimension vector is reversed.
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      const int64_t in_rank = static_cast<int64_t>(in_dim.size());
      if (in_rank < dim_size) {
        // Without this check out_dims[axis] / tmp_dim[axis] would be indexed
        // out of bounds.
        if (axis < 0 || axis + in_rank > dim_size) {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "The axis is expected to be in range [0, %d] for an input of "
              "rank %d and an output of rank %d, but received %d.\n",
              dim_size - in_rank, in_rank, dim_size, axis));
        }
        DimVector tmp_dim(dim_size, 1);
        // Per-input cursor: `axis` itself must stay untouched so that every
        // lower-rank input is aligned at the same output position.
        int64_t out_idx = axis;
        for (int64_t in_idx = 0; in_idx < in_rank; ++in_idx, ++out_idx) {
          if (in_dim[in_idx] != out_dims[out_idx] && in_dim[in_idx] != 1) {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "with the %dth dimension of output tensor %d or 1, but "
                "received %d.\n",
                in_idx + 1, out_idx + 1, out_dims[out_idx], in_dim[in_idx]));
          }
          tmp_dim[out_idx] = in_dim[in_idx];
        }
        in_dim = tmp_dim;
      } else {
        for (int64_t in_idx = 0; in_idx < dim_size; ++in_idx) {
          if (in_dim[in_idx] != out_dims[in_idx] && in_dim[in_idx] != 1) {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "with the %dth dimension of output tensor %d or 1, but "
                "received %d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        }
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  // Folds every maximal run of consecutive dimensions for which `merge_func`
  // leaves `equal == true` into a single dimension (the product of the run),
  // in every input and in the output, and shrinks `dim_size` accordingly.
  void MergeDimensions(MergeFunctor merge_func, int N) {
    // Replaces vec[l_idx, m_idx) by one element holding their product.
    // The int64_t init keeps the accumulation in 64 bits.
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx,
                          static_cast<int64_t>(1), std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (!equal) {
          break;
        }
        i++;
        cnt++;
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        // `cnt` dimensions collapsed into one; resume right after it.
        dim_size -= cnt - 1;
        i -= cnt - 1;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    // Dimension i is mergeable when every input has the same extent there.
    // `&=` makes the verdict depend on all inputs, not only the last one.
    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal &= (in_dims[0][i] == in_dims[j][i]);
      }
    };
    // Dimension i is mergeable when input 0 has extent 1 there and every
    // other input matches the output extent.
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal &= (in_dims[j][i] == out[i]);
        }
      }
    };
    // Merge the dimensions of the input tensors where consecutive
    // equal-extent dimensions appear.
    MergeDimensions(merge_sequential_dims, N);

    // Move the input with the fewest elements into slot 0 (ties go to the
    // later input) so the 1-run merge keys on it; the swap is undone below.
    auto numel_of = [](const DimVector &d) {
      return std::accumulate(d.begin(), d.end(), static_cast<int64_t>(1),
                             std::multiplies<int64_t>());
    };
    int min_idx = 0;
    int64_t min_val = numel_of(in_dims[0]);
    for (int j = 1; j < N; ++j) {
      const int64_t temp = numel_of(in_dims[j]);
      if (temp <= min_val) {
        min_val = temp;
        min_idx = j;
      }
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // Merge the dimensions of the input tensors where consecutive
    // 1-valued dimensions appear.
    MergeDimensions(merge_sequential_one_dims, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};

// Precomputes what the broadcast kernel needs to turn a flat output index
// into per-input element offsets:
//   * divmoders[i]  - a FastDivMod built from out_dims[i];
//   * strides[j][i] - element stride of input j along dimension i, or 0 when
//                     that input has extent 1 there (the element is reused).
// Dimension vectors are read with index 0 as the innermost dimension, as
// produced by DimensionsTransform.
struct StridesCalculation {
  std::vector<std::vector<uint32_t>> strides;
  std::vector<FastDivMod> divmoders;

 private:
  // strides[j][i] = product of in_dims[j][0..i), or 0 if in_dims[j][i] == 1.
  __inline__ void CalculateStrides(
      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
    for (int input = 0; input < N; ++input) {
      const std::vector<int64_t> &extent = in_dims[input];
      std::vector<uint32_t> &stride = strides[input];
      int64_t running_product = 1;
      for (int d = 0; d < dim_size; ++d) {
        stride[d] =
            (extent[d] == 1) ? 0u : static_cast<uint32_t>(running_product);
        running_product *= extent[d];
      }
    }
  }

 public:
  explicit StridesCalculation(const int64_t &dim_size,
                              const std::vector<std::vector<int64_t>> &in_dims,
                              const std::vector<int64_t> &out_dims) {
    const auto N = in_dims.size();
    strides.resize(N, std::vector<uint32_t>(dim_size, 1));

    divmoders.reserve(dim_size);
    for (int d = 0; d < dim_size; ++d) {
      divmoders.push_back(FastDivMod(out_dims[d]));
    }
    CalculateStrides(N, dim_size, in_dims);
  }
};

// Argument bundle for ElementwiseBroadcastKernel with ET inputs, vector width
// VecSize and kDims merged dimensions. It is built on the host and passed to
// the kernel by value, so all members are plain pointers / fixed arrays.
template <typename T, typename Functor, ElementwiseType ET, int VecSize,
          int kDims>
struct BroadcastArgsWarpper {
  // The constructor copies kDims entries into arrays sized kMaxRank.
  static_assert(kDims <= framework::DDim::kMaxRank,
                "kDims must not exceed framework::DDim::kMaxRank.");
  using VecType = CudaAlignedVector<T, VecSize>;

  T *out_data;
  VecType *vec_out_data;
  const T *__restrict__ in_data[ET];
  const VecType *__restrict__ vec_in_data[ET];
  // True when input j has exactly the output's dims, so it can be read at
  // the output's flat index without any offset computation.
  bool no_broadcast[ET];
  FastDivMod divmoders[kDims];
  uint32_t strides[ET][framework::DDim::kMaxRank];
  // Flat output index at which the scalar (non-vectorized) part starts.
  uint32_t scalar_cal_offset;
  Functor func;

  // Host-only constructor: it reads std::vector storage, calls
  // Tensor::data<T>() and memcpy, none of which are usable on the device, so
  // it is intentionally not marked HOSTDEVICE.
  BroadcastArgsWarpper(
      const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
      int scalar_cal_offset, Functor func,
      const StridesCalculation &offset_calculator)
      : scalar_cal_offset(scalar_cal_offset), func(func) {
    for (int j = 0; j < ET; ++j) {
      in_data[j] = ins[j]->data<T>();
      vec_in_data[j] = reinterpret_cast<const VecType *>(in_data[j]);
      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
      memcpy(strides[j], offset_calculator.strides[j].data(),
             kDims * sizeof(uint32_t));
    }
    out_data = out->data<T>();
    vec_out_data = reinterpret_cast<VecType *>(out_data);
    memcpy(divmoders, offset_calculator.divmoders.data(),
           kDims * sizeof(FastDivMod));
  }

  // Maps flat output index `idx` to an element offset into input `in_idx`.
  // Each dimension peels one coordinate off with divmoders[i] and weights it
  // by that input's stride. Assumes Divmod yields {quotient, remainder} in
  // val[0] / val[1] — TODO confirm against FastDivMod.
  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
    uint32_t offset = 0;

#pragma unroll(kDims)
    for (int i = 0; i < kDims; ++i) {
      auto fast_divmoder = divmoders[i].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += fast_divmoder.val[1] * strides[in_idx][i];
    }
    return offset;
  }

  // One aligned vector load of chunk `tid` from non-broadcast input `idx`.
  __device__ __forceinline__ void LoadVectorizedDataCommon(VecType *vector_args,
                                                           int tid, int idx) {
    *vector_args = vec_in_data[idx][tid];
  }

  // Gathers the VecSize elements of chunk `tid` from broadcast input `idx`,
  // one divmod offset computation per element.
  __device__ __forceinline__ void LoadVectorizedDataByDivmod(T *scalar_args,
                                                             int tid, int idx) {
    int index = tid * VecSize;
#pragma unroll(VecSize)
    for (int i = 0; i < VecSize; ++i) {
      uint32_t offset = GetOffsetByDivmod(index + i, idx);
      scalar_args[i] = in_data[idx][offset];
    }
  }

  // Scalar load of tail element `tid` from non-broadcast input `idx`.
  __device__ __forceinline__ void LoadScalarizedDataCommon(T args[], int tid,
                                                           int idx) {
    args[idx] = in_data[idx][tid + scalar_cal_offset];
  }

  // Scalar load of tail element `tid` from broadcast input `idx`.
  __device__ __forceinline__ void LoadScalarizedDataByDivmod(T args[], int tid,
                                                             int idx) {
    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
    args[idx] = in_data[idx][offset];
  }

  // Loads chunk `tid` of every input into args[j][0..VecSize).
  __device__ __forceinline__ void LoadVectorizedData(T (*args)[VecSize],
                                                     int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        VecType *vector_args = reinterpret_cast<VecType *>(args[j]);
        LoadVectorizedDataCommon(vector_args, tid, j);
      } else {
        LoadVectorizedDataByDivmod(args[j], tid, j);
      }
    }
  }

  // Loads tail element `tid` of every input into args[j].
  __device__ __forceinline__ void LoadScalarizedData(T args[], int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        LoadScalarizedDataCommon(args, tid, j);
      } else {
        LoadScalarizedDataByDivmod(args, tid, j);
      }
    }
  }

  // Stores args[0] (the results) as output chunk `tid` with one vector store.
  __device__ __forceinline__ void StoreVectorizedData(T (*args)[VecSize],
                                                      int tid) {
    VecType *args_out = reinterpret_cast<VecType *>(args[0]);
    vec_out_data[tid] = *args_out;
  }

  // Stores args[0] as tail output element `tid`.
  __device__ __forceinline__ void StoreScalarizedData(T args[], int tid) {
    out_data[scalar_cal_offset + tid] = args[0];
  }
};

// Computes one tail output element for thread `tid`. It loads one scalar from
// each of the ET inputs, applies the functor once to all of them, and stores
// the result.
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
    BroadcastArgsWarpper broadcast_warpper, int tid) {
  T args[ET];
  broadcast_warpper.LoadScalarizedData(args, tid);

  // The functor consumes all ET operands at once, exactly as in
  // VectorizedBroadcastKernelImpl. Calling it ET-1 times in a loop would feed
  // the previous result back in as args[0] whenever ET > 2.
  args[0] = broadcast_warpper.func(args);
  broadcast_warpper.StoreScalarizedData(args, tid);
}

// Computes the VecSize output elements of chunk `tid`. It loads that chunk
// from every input, applies the functor element by element, and writes the
// chunk back with a single vector store. Results overwrite the first input's
// slice before the store.
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
          int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
    BroadcastArgsWarpper broadcast_warpper, int tid) {
  T operands[ET];
  T slices[ET][VecSize];
  broadcast_warpper.LoadVectorizedData(slices, tid);

#pragma unroll(VecSize)
  for (int lane = 0; lane < VecSize; ++lane) {
#pragma unroll(ET)
    for (int input = 0; input < ET; ++input) {
      operands[input] = slices[input][lane];
    }
    slices[0][lane] = broadcast_warpper.func(operands);
  }

  broadcast_warpper.StoreVectorizedData(slices, tid);
}

// Broadcast elementwise kernel, one thread per VecSize-wide chunk.
//   * Threads with tid < main_tid each compute one full vectorized chunk
//     (the largest prefix of the output whose length is a multiple of
//     VecSize).
//   * Threads with tid < tail_tid additionally compute one leftover scalar
//     element of the remainder.
// e.g. 1027 elements with VecSize 4: the first 1024 elements are computed
// vectorized, and the last 3 are computed as scalars.
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
          int VecSize>
__global__ void ElementwiseBroadcastKernel(
    BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const bool owns_vector_chunk = tid < main_tid;
  const bool owns_tail_element = tid < tail_tid;

  if (owns_vector_chunk) {
    VectorizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET, VecSize>(
        broadcast_warpper, tid);
  }
  if (owns_tail_element) {
    ScalarizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET>(
        broadcast_warpper, tid);
  }
}

// Builds the argument bundle for a fixed merged rank kDims and launches
// ElementwiseBroadcastKernel on the context's stream.
template <typename T, ElementwiseType ET, int VecSize, int kDims,
          typename Functor>
void LaunchBroadcastKernelForFixedDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    Functor func, const StridesCalculation &offset_calculator, int vec_len,
    int main_tid, int tail_tid, int blocks, int threads) {
  auto broadcast_warpper = BroadcastArgsWarpper<T, Functor, ET, VecSize, kDims>(
      ins, out, vec_len, func, offset_calculator);
  ElementwiseBroadcastKernel<T, decltype(broadcast_warpper), ET,
                             VecSize><<<blocks, threads, 0, ctx.stream()>>>(
      broadcast_warpper, main_tid, tail_tid);
}

// Merges the broadcast dimensions of `ins` against `out` and launches the
// broadcast kernel instantiated for the resulting rank (1..kMaxRank).
// Throws InvalidArgument on incompatible dims or an unsupported merged rank.
template <typename T, ElementwiseType ET, int VecSize, typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  const int numel = out->numel();
  const int threads = 256;
  // One thread per VecSize chunk, rounded up so the scalar tail is covered.
  const int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  const int main_tid = numel / VecSize;
  const int tail_tid = numel % VecSize;
  // The scalar tail starts right after the last full vector chunk.
  const int vec_len = main_tid * VecSize;

  // Validates the input dims against the output (throws on mismatch).
  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
  // Nothing to compute. Launching a grid of zero blocks would be an invalid
  // launch configuration.
  if (numel == 0) {
    return;
  }
  const auto offset_calculator = StridesCalculation(
      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);

#define PD_LAUNCH_BROADCAST_FOR_RANK(rank)                                  \
  case rank: {                                                              \
    LaunchBroadcastKernelForFixedDimSize<T, ET, VecSize, rank>(             \
        ctx, ins, out, func, offset_calculator, vec_len, main_tid, tail_tid, \
        blocks, threads);                                                   \
    break;                                                                  \
  }

  switch (merge_dims.dim_size) {
    PD_LAUNCH_BROADCAST_FOR_RANK(1)
    PD_LAUNCH_BROADCAST_FOR_RANK(2)
    PD_LAUNCH_BROADCAST_FOR_RANK(3)
    PD_LAUNCH_BROADCAST_FOR_RANK(4)
    PD_LAUNCH_BROADCAST_FOR_RANK(5)
    PD_LAUNCH_BROADCAST_FOR_RANK(6)
    PD_LAUNCH_BROADCAST_FOR_RANK(7)
    PD_LAUNCH_BROADCAST_FOR_RANK(8)
    default: {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "or equal to %d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
#undef PD_LAUNCH_BROADCAST_FOR_RANK
}

// Picks the widest vector width (4, 2 or 1 elements) usable for this launch,
// then dispatches to the broadcast launcher.
// - Only inputs whose dims equal the output's are read with vector loads, so
//   only their alignment limits the width.
// - The output is always written with vector stores, so its alignment always
//   limits the width.
// Only binary ops are supported.
template <ElementwiseType ET, typename T, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
  std::vector<framework::Tensor *> &out_list = *outs;
  framework::Tensor *out = out_list[0];

  int in_vec_size = 4;
  for (const framework::Tensor *in : ins) {
    const int candidate = GetVectorizedSizeImpl<T>(in->data<T>());
    if (in->dims() == out->dims()) {
      in_vec_size = std::min(candidate, in_vec_size);
    }
  }
  const int vec_size =
      std::min(GetVectorizedSizeImpl<T>(out->data<T>()), in_vec_size);

  if (vec_size == 4) {
    LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis,
                                                       func);
  } else if (vec_size == 2) {
    LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis,
                                                       func);
  } else if (vec_size == 1) {
    LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis,
                                                       func);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported vectorized size: %d !", vec_size));
  }
}

// Entry point for CUDA elementwise ops. If every input has the same dims as
// ins[0], the same-dims kernel is used; otherwise the broadcast path is taken.
// NOTE(review): both paths are instantiated with ElementwiseType::kBinary
// regardless of ET (the broadcast path static_asserts binary), so ET is not
// forwarded here.
template <ElementwiseType ET, typename InT, typename OutType, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  // Accumulate rather than overwrite: a mismatch in any input (not just the
  // last one) requires the broadcast path.
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= (ins[0]->dims() == in->dims());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
        cuda_ctx, ins, outs, func);
  } else {
    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT>(
        cuda_ctx, ins, outs, axis, func);
  }
}
}  // namespace operators
}  // namespace paddle