// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"

namespace paddle {
namespace operators {

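// Preprocesses the shapes involved in a broadcast: pads the shorter input
// shapes according to 'axis', then merges consecutive dimensions that can be
// treated as one, so the device-side offset computation has to handle as few
// dimensions as possible.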
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // To compensate for the missing dimensions of the input tensors, aligning
  // them to the output dimensions according to the input variable 'axis'.
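  // e.g. with out_dims = [2, 3, 4, 5], an input of shape [3, 4] and axis = 1
  // is extended to [1, 3, 4, 1] (all shapes are then reversed so that the
  // fastest-varying dimension comes first).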
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "with"
                "the %dth dimension of output tensor %d or 1, but recieved "
                "%d.\n",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "with"
                "the %dth dimension of output tensor %d or 1, but recieved "
                "%d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

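  // Collapses runs of adjacent dimensions on which the inputs either all
  // agree, or on which the first (smallest) input is broadcast (size 1) while
  // the rest match the output, multiplying each run into a single dimension.
  // e.g. in_dims {[2, 3, 4], [2, 3, 1]} with out_dims [2, 3, 4] become
  // {[6, 4], [6, 1]} and [6, 4].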
  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal = in_dims[j][i] == out[i];
        }
      }
    };
    // To merge the dimensions of the input tensors where consecutive
    // equal dimensions appear.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // To merge the dimensions of the input tensors where consecutive
    // 1-valued dimensions appear.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};

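// Precomputes, for every input, the stride of each (merged) output dimension,
// with broadcast dimensions mapped to stride 0, plus FastDivMod helpers for
// decomposing a flat output index into per-dimension coordinates on device.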
struct StridesCalculation {
  std::vector<std::vector<uint32_t>> strides;
  std::vector<platform::FastDivMod> divmoders;

 private:
  // To calculate the strides of each input_tensor.
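  // e.g. for a merged input shape [4, 3, 1] (fastest-varying dimension first)
  // the strides become [1, 4, 0]; the zero makes the broadcast dimension
  // contribute nothing to the input offset.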
  __inline__ void CalculateStrides(
      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
    for (int j = 0; j < N; ++j) {
      for (int i = 0; i < dim_size; ++i) {
        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
        strides[j][i] =
            (i != 0 && strides[j][i] != 0)
                ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1,
                                  std::multiplies<int64_t>())
                : strides[j][i];
      }
    }
  }

 public:
  explicit StridesCalculation(const int64_t &dim_size,
                              const std::vector<std::vector<int64_t>> &in_dims,
                              const std::vector<int64_t> &out_dims) {
    const auto N = in_dims.size();
    divmoders.resize(dim_size);
    strides.resize(N, std::vector<uint32_t>(dim_size, 1));

    for (int i = 0; i < dim_size; ++i) {
      divmoders[i] = platform::FastDivMod(out_dims[i]);
    }
    CalculateStrides(N, dim_size, in_dims);
  }
};

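// Device-side argument pack for the broadcast kernel: raw and vectorized
// pointers for each input and for the output, per-input strides, FastDivMod
// helpers, and a per-input flag that skips the index computation when an
// input already has the same shape as the output.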
template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
          int VecSize, int kDims>
struct BroadcastArgsWrapper {
  using InVecType = platform::CudaAlignedVector<InT, VecSize>;
  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;

  OutT *out_data;
  OutVecType *vec_out_data;
  const InT *__restrict__ in_data[ET];
  const InVecType *__restrict__ vec_in_data[ET];
  bool no_broadcast[ET];
  platform::FastDivMod divmoders[kDims];
  uint32_t strides[ET][framework::DDim::kMaxRank];
  uint32_t scalar_cal_offset;
  Functor func;

  HOSTDEVICE BroadcastArgsWrapper(
      const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
      int scalar_cal_offset, Functor func,
      const StridesCalculation &offset_calculator)
      : scalar_cal_offset(scalar_cal_offset), func(func) {
    for (int j = 0; j < ET; ++j) {
      in_data[j] = ins[j]->data<InT>();
      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
      memcpy(strides[j], offset_calculator.strides[j].data(),
             kDims * sizeof(uint32_t));
    }
    out_data = out->data<OutT>();
    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
    memcpy(divmoders, offset_calculator.divmoders.data(),
           kDims * sizeof(platform::FastDivMod));
  }

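  // Decomposes the flat output index 'idx' into per-dimension coordinates
  // with FastDivMod and accumulates them against the strides of input
  // 'in_idx'; broadcast dimensions have stride 0 and thus drop out.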
  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
    uint32_t offset = 0;

#pragma unroll(kDims)
    for (int i = 0; i < kDims; ++i) {
      auto fast_divmoder = divmoders[i].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += fast_divmoder.val[1] * strides[in_idx][i];
    }
    return offset;
  }

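  // The *Common loaders read inputs whose shape already equals the output
  // shape directly; the *ByDivmod loaders recover a broadcast-aware offset
  // per element via GetOffsetByDivmod.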
  __device__ __forceinline__ void LoadVectorizedDataCommon(
      InVecType *vector_args, int tid, int idx) {
    *vector_args = vec_in_data[idx][tid];
  }

  __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
                                                             int tid, int idx) {
    int index = tid * VecSize;
#pragma unroll(VecSize)
    for (int i = 0; i < VecSize; ++i) {
      uint32_t offset = GetOffsetByDivmod(index + i, idx);
      scalar_args[i] = in_data[idx][offset];
    }
  }

  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
                                                           int idx) {
    args[idx] = in_data[idx][tid + scalar_cal_offset];
  }

  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
                                                             int tid, int idx) {
    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
    args[idx] = in_data[idx][offset];
  }

  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
                                                     int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
        LoadVectorizedDataCommon(vector_args, tid, j);
      } else {
        LoadVectorizedDataByDivmod(args[j], tid, j);
      }
    }
  }

  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        LoadScalarizedDataCommon(args, tid, j);
      } else {
        LoadScalarizedDataByDivmod(args, tid, j);
      }
    }
  }

  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
                                                      int tid) {
    vec_out_data[tid] = vec_args_out;
  }

  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
    out_data[scalar_cal_offset + tid] = args_out;
  }
};

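// Each thread computes a single output element from ET scalar inputs; used
// for the tail that does not fill a whole vector.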
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
          ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
    BroadcastArgsWrapper broadcast_wrapper, int tid) {
  InT args[ET];
  OutT args_out;
  broadcast_wrapper.LoadScalarizedData(args, tid);

  // Calculation on the loaded input data.
  args_out = broadcast_wrapper.func(args);

  broadcast_wrapper.StoreScalarizedData(args_out, tid);
}

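// Each thread computes VecSize output elements: it loads a VecSize-wide slice
// of every input, applies the functor element by element, and stores the
// result as one vector.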
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
          ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
    BroadcastArgsWrapper broadcast_wrapper, int tid) {
  using OutVecType = platform::CudaAlignedVector<OutT, VecSize>;
  OutVecType args_out;
  InT ins[ET];
  InT args[ET][VecSize];
  broadcast_wrapper.LoadVectorizedData(args, tid);

#pragma unroll(VecSize)
  for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      ins[j] = args[j][i];
    }
    args_out.val[i] = broadcast_wrapper.func(ins);
  }
  broadcast_wrapper.StoreVectorizedData(args_out, tid);
}

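// Threads with tid < main_tid each process one vector of VecSize output
// elements; threads with tid < tail_tid process one of the remaining tail
// elements (offset by scalar_cal_offset inside the wrapper).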
template <typename InT, typename OutT, typename BroadcastArgsWrapper,
          ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  // Vectorized calculation of the main part, whose length is the largest
  // multiple of VecSize, e.g. the first 1024 of 1027 elements when VecSize
  // is 4.
  if (tid < main_tid) {
    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
        broadcast_wrapper, tid);
  }
  // Scalarized calculation of the remaining elements that cannot fill a full
  // VecSize, e.g. the last 3 of 1027 elements when VecSize is 4.
  if (tid < tail_tid) {
    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
        broadcast_wrapper, tid);
  }
}

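// Host-side launcher: computes the launch configuration, merges the broadcast
// shapes, precomputes strides and divmods, and instantiates the kernel with
// the merged rank as a compile-time constant.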
template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  int numel = out->numel();
  int threads = GetThreadsConfig(ctx, numel, VecSize);
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  int main_tid = numel / VecSize;
  int tail_tid = numel % VecSize;
  int vec_len = main_tid * VecSize;
  auto stream = ctx.stream();

  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
  const auto offset_calculator = StridesCalculation(
      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);

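  // The merged rank is only known at run time, so dispatch to the
  // BroadcastArgsWrapper specialization whose compile-time kDims matches it.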
  switch (merge_dims.dim_size) {
    case 1: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 2: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 3: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 4: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 5: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 6: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 7: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    case 8: {
      auto broadcast_wrapper =
          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_wrapper, main_tid, tail_tid);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
}

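// Chooses the vectorization width for the broadcast path: starts from the
// widest supported width (4) and lowers it to match the alignment of the
// output and of every input that shares the output shape, then dispatches on
// the resulting width.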
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  PADDLE_ENFORCE_EQ(ET, ElementwiseType::kBinary,
                    platform::errors::InvalidArgument(
                        "Currently, only support binary calculation, "
                        "but received %d input tensors.\n",
                        static_cast<int>(ET)));
  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

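// Entry point shared by the elementwise kernels: falls back to the same-dims
// fast path when every input matches the first input's shape, otherwise
// derives a default 'axis' from the rank difference and dispatches to the
// broadcast path.
//
// A minimal usage sketch (illustrative only; 'AddFunctor' is a hypothetical
// functor with 'HOSTDEVICE T operator()(const T *args) const' returning
// args[0] + args[1]):
//
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       dev_ctx, ins, &outs, /*axis=*/-1, AddFunctor<T>());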
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                       func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
}

}  // namespace operators
}  // namespace paddle