// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"

namespace paddle {
namespace operators {

struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // To compensate for the missing dimensions of the input tensors by using
  // the input variable 'axis'
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "to the %dth dimension of output tensor %d or 1, but received "
                "%d.\n",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "to the %dth dimension of output tensor %d or 1, but received "
                "%d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal = in_dims[j][i] == out[i];
        }
      }
    };
    // To merge the dimensions of the input tensors when consecutive
    // equal dimensions appear.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // To merge the dimensions of the input tensors when consecutive
    // 1-valued dimensions appear.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};
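
// Illustrative example of what DimensionsTransform produces (hypothetical
// shapes, not taken from any particular op): broadcasting an input of shape
// [3, 4] against an output of shape [2, 3, 4, 5] at axis = 1,
// InputDimensionsExtend first pads the input to [1, 3, 4, 1]; all dims are
// then stored reversed (fastest-varying axis first) and MergeDimensions folds
// consecutive dimensions that can be treated as one, collapsing [1, 4, 3, 1]
// vs. [5, 4, 3, 2] into [1, 12, 1] vs. [5, 12, 2], so the kernels index at
// most three merged dimensions instead of four.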

struct StridesCalculation {
  std::vector<std::vector<uint32_t>> strides;
  std::vector<FastDivMod> divmoders;

 private:
  // To calculate the strides of each input_tensor.
  __inline__ void CalculateStrides(
      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
    for (int j = 0; j < N; ++j) {
      for (int i = 0; i < dim_size; ++i) {
        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
        strides[j][i] =
            (i != 0 && strides[j][i] != 0)
                ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1,
                                  std::multiplies<int64_t>())
                : strides[j][i];
      }
    }
  }

 public:
  explicit StridesCalculation(const int64_t &dim_size,
                              const std::vector<std::vector<int64_t>> &in_dims,
                              const std::vector<int64_t> &out_dims) {
    const auto N = in_dims.size();
    divmoders.resize(dim_size);
    strides.resize(N, std::vector<uint32_t>(dim_size, 1));

    for (int i = 0; i < dim_size; ++i) {
      divmoders[i] = FastDivMod(out_dims[i]);
    }
    CalculateStrides(N, dim_size, in_dims);
  }
};
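
// Illustrative example (continuing the hypothetical shapes above): for merged
// dims stored fastest-varying first, e.g. out_dims = [5, 12, 2] with a
// broadcast input of [1, 12, 1], that input's strides become [0, 1, 0]:
// size-1 (broadcast) dimensions get stride 0, and every other dimension keeps
// the product of the input dimensions preceding it. divmoders holds one
// FastDivMod per merged output dimension so the kernels can decompose a
// linear output index without runtime integer division.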

template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
          int VecSize, int kDims>
struct BroadcastArgsWarpper {
  using InVecType = CudaAlignedVector<InT, VecSize>;
  using OutVecType = CudaAlignedVector<OutT, VecSize>;

  OutT *out_data;
  OutVecType *vec_out_data;
  const InT *__restrict__ in_data[ET];
  const InVecType *__restrict__ vec_in_data[ET];
  bool no_broadcast[ET];
  FastDivMod divmoders[kDims];
  uint32_t strides[ET][framework::DDim::kMaxRank];
  uint32_t scalar_cal_offset;
  Functor func;

  HOSTDEVICE BroadcastArgsWarpper(
      const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
      int scalar_cal_offset, Functor func,
      const StridesCalculation &offset_calculator)
      : scalar_cal_offset(scalar_cal_offset), func(func) {
    for (int j = 0; j < ET; ++j) {
      in_data[j] = ins[j]->data<InT>();
      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
      memcpy(strides[j], offset_calculator.strides[j].data(),
             kDims * sizeof(uint32_t));
    }
    out_data = out->data<OutT>();
    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
    memcpy(divmoders, offset_calculator.divmoders.data(),
           kDims * sizeof(FastDivMod));
  }

  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
    uint32_t offset = 0;

#pragma unroll(kDims)
    for (int i = 0; i < kDims; ++i) {
      auto fast_divmoder = divmoders[i].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += fast_divmoder.val[1] * strides[in_idx][i];
    }
    return offset;
  }
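
  // Illustrative trace of GetOffsetByDivmod (hypothetical values, continuing
  // the example above): with divmoders built from out_dims = [5, 12, 2] and
  // strides[in_idx] = [0, 1, 0], idx = 17 is decomposed into the coordinates
  // (2, 3, 0), giving offset = 2 * 0 + 3 * 1 + 0 * 0 = 3, i.e. the element of
  // the broadcast input that output element 17 reads from.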

  __device__ __forceinline__ void LoadVectorizedDataCommon(
      InVecType *vector_args, int tid, int idx) {
    *vector_args = vec_in_data[idx][tid];
  }

  __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
                                                             int tid, int idx) {
    int index = tid * VecSize;
#pragma unroll(VecSize)
    for (int i = 0; i < VecSize; ++i) {
      uint32_t offset = GetOffsetByDivmod(index + i, idx);
      scalar_args[i] = in_data[idx][offset];
    }
  }

  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
                                                           int idx) {
    args[idx] = in_data[idx][tid + scalar_cal_offset];
  }

  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
                                                             int tid, int idx) {
    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
    args[idx] = in_data[idx][offset];
  }

  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
                                                     int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
        LoadVectorizedDataCommon(vector_args, tid, j);
      } else {
        LoadVectorizedDataByDivmod(args[j], tid, j);
      }
    }
  }

  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        LoadScalarizedDataCommon(args, tid, j);
      } else {
        LoadScalarizedDataByDivmod(args, tid, j);
      }
    }
  }

  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
                                                      int tid) {
    vec_out_data[tid] = vec_args_out;
  }

  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
    out_data[scalar_cal_offset + tid] = args_out;
  }
};

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
          ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
    BroadcastArgsWarpper broadcast_warpper, int tid) {
  InT args[ET];
  OutT args_out;
  broadcast_warpper.LoadScalarizedData(args, tid);

#pragma unroll(ET)
  for (int j = 1; j < ET; ++j) {
    args_out = broadcast_warpper.func(args);
  }
  broadcast_warpper.StoreScalarizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
          ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
    BroadcastArgsWarpper broadcast_warpper, int tid) {
  using OutVecType = CudaAlignedVector<OutT, VecSize>;
  OutVecType args_out;
  InT ins[ET];
  InT args[ET][VecSize];
  broadcast_warpper.LoadVectorizedData(args, tid);

#pragma unroll(VecSize)
  for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      ins[j] = args[j][i];
    }
    args_out.val[i] = broadcast_warpper.func(ins);
  }
  broadcast_warpper.StoreVectorizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
          ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
    BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  // Vectorized calculation of the main data, whose length is the largest
  // multiple of VecSize, e.g. the first 1024 of 1027 elements when VecSize
  // is 4.
  if (tid < main_tid) {
    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
        broadcast_warpper, tid);
  }
  // Scalarized calculation of the remaining data, whose length is smaller
  // than VecSize, e.g. the last 3 of 1027 elements when VecSize is 4.
  if (tid < tail_tid) {
    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
        broadcast_warpper, tid);
  }
}

template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  int numel = out->numel();
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  int main_tid = numel / VecSize;
  int tail_tid = numel % VecSize;
  int vec_len = main_tid * VecSize;
  auto stream = ctx.stream();
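  // Worked example (hypothetical sizes): with numel = 1027 and VecSize = 4,
  // blocks = ceil(ceil(1027 / 4) / 256) = 2; main_tid = 256 vectorized
  // threads cover the first vec_len = 1024 elements, and tail_tid = 3 scalar
  // threads cover the remainder.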

  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
  const auto offset_calculator = StridesCalculation(
      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);

  switch (merge_dims.dim_size) {
    case 1: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 2: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 3: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 4: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 5: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 6: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 7: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 8: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  static_assert(ET == (ElementwiseType)2, "Only Support binary calculation.");
  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}
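
// Note on the vectorization width chosen above (illustrative alignment
// values, not guaranteed by any particular allocator): if the output pointer
// supports 4-element vector accesses but one same-shaped input only supports
// 2-element accesses, vec_size resolves to 2; broadcast inputs do not
// constrain the choice, since they are gathered element by element through
// GetOffsetByDivmod regardless of vec_size.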

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
    const framework::ExecutionContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }
  const auto &cuda_ctx =
      ctx.template device_context<platform::CUDADeviceContext>();
  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutT>(
        cuda_ctx, ins, outs, func);
  } else {
    int axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
}
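
// A minimal usage sketch (hypothetical functor and caller, not part of this
// header): an op's CUDA kernel would typically gather its inputs and the
// pre-allocated output into vectors and dispatch through
// LaunchElementwiseCudaKernel, e.g.
//
//   template <typename T>
//   struct AddFunctor {
//     inline HOSTDEVICE T operator()(const T *args) const {
//       return args[0] + args[1];
//     }
//   };
//
//   std::vector<const framework::Tensor *> ins = {x, y};
//   std::vector<framework::Tensor *> outs = {z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       exec_ctx, ins, &outs, AddFunctor<T>());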

}  // namespace operators
}  // namespace paddle