// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"

namespace paddle {
namespace operators {

struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // To pad the input tensors' dimensions up to the output rank, using the
  // input variable 'axis' to align them.
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of the input tensor is expected to be "
                "equal to the %dth dimension of the output tensor (%d) or to "
                "be 1, but received %d.\n",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of the input tensor is expected to be "
                "equal to the %dth dimension of the output tensor (%d) or to "
                "be 1, but received %d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal &= (in_dims[0][i] == in_dims[j][i]);
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal &= (in_dims[j][i] == out[i]);
        }
      }
    };
    // To merge the dimensions of the input tensors wherever consecutive
    // equal dimensions appear across all inputs.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // To merge the dimensions of the input tensors wherever consecutive
    // 1-valued dimensions appear.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};
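// Illustration of DimensionsTransform (hypothetical shapes, binary case):
// with output dims [2, 3, 4, 5], an input of shape [3, 4] and axis = 1 is
// first padded to [1, 3, 4, 1]. After reversing to innermost-first order and
// merging consecutive dimensions that are equal across all inputs, the result
// is roughly in_dims = {[5, 12, 2], [1, 12, 1]}, out_dims = [5, 12, 2] and
// dim_size = 3.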

struct StridesCalculation {
  std::vector<std::vector<uint32_t>> strides;
  std::vector<FastDivMod> divmoders;

 private:
  // To calculate the strides of each input_tensor.
  __inline__ void CalculateStrides(
      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
    for (int j = 0; j < N; ++j) {
      for (int i = 0; i < dim_size; ++i) {
        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
        strides[j][i] =
            (i != 0 && strides[j][i] != 0)
                ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1,
                                  std::multiplies<int64_t>())
                : strides[j][i];
      }
    }
  }

 public:
  explicit StridesCalculation(const int64_t &dim_size,
                              const std::vector<std::vector<int64_t>> &in_dims,
                              const std::vector<int64_t> &out_dims) {
    const auto N = in_dims.size();
    divmoders.resize(dim_size);
    strides.resize(N, std::vector<uint32_t>(dim_size, 1));

    for (int i = 0; i < dim_size; ++i) {
      divmoders[i] = FastDivMod(out_dims[i]);
    }
    CalculateStrides(N, dim_size, in_dims);
  }
};
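// Illustration of StridesCalculation (continuing the hypothetical example
// above): for merged in_dims = {[5, 12, 2], [1, 12, 1]} and
// out_dims = [5, 12, 2], the computed strides are {[1, 5, 60], [0, 1, 0]};
// broadcast dimensions get a stride of 0, and divmoders holds one FastDivMod
// per output dimension.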

template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
          int VecSize, int kDims>
struct BroadcastArgsWarpper {
  using InVecType = CudaAlignedVector<InT, VecSize>;
  using OutVecType = CudaAlignedVector<OutT, VecSize>;

  OutT *out_data;
  OutVecType *vec_out_data;
  const InT *__restrict__ in_data[ET];
  const InVecType *__restrict__ vec_in_data[ET];
  bool no_broadcast[ET];
  FastDivMod divmoders[kDims];
  uint32_t strides[ET][framework::DDim::kMaxRank];
  uint32_t scalar_cal_offset;
  Functor func;

  HOSTDEVICE BroadcastArgsWarpper(
      const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
      int scalar_cal_offset, Functor func,
      const StridesCalculation &offset_calculator)
      : scalar_cal_offset(scalar_cal_offset), func(func) {
    for (int j = 0; j < ET; ++j) {
      in_data[j] = ins[j]->data<InT>();
      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
      no_broadcast[j] = ins[j]->dims() == out->dims();
      memcpy(strides[j], offset_calculator.strides[j].data(),
             kDims * sizeof(uint32_t));
    }
    out_data = out->data<OutT>();
    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
    memcpy(divmoders, offset_calculator.divmoders.data(),
           kDims * sizeof(FastDivMod));
  }

  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
    uint32_t offset = 0;

#pragma unroll(kDims)
    for (int i = 0; i < kDims; ++i) {
      auto fast_divmoder = divmoders[i].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += fast_divmoder.val[1] * strides[in_idx][i];
    }
    return offset;
  }
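  // For instance (hypothetical values, continuing the example above): with
  // out_dims = [5, 12, 2] and strides = [0, 1, 0] for a broadcast input,
  // output index idx = 17 yields per-dimension remainders (2, 3, 0), so the
  // gathered input offset is 2 * 0 + 3 * 1 + 0 * 0 = 3.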

  __device__ __forceinline__ void LoadVectorizedDataCommon(
      InVecType *vector_args, int tid, int idx) {
    *vector_args = vec_in_data[idx][tid];
  }

  __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
                                                             int tid, int idx) {
    int index = tid * VecSize;
#pragma unroll(VecSize)
    for (int i = 0; i < VecSize; ++i) {
      uint32_t offset = GetOffsetByDivmod(index + i, idx);
      scalar_args[i] = in_data[idx][offset];
    }
  }

  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
                                                           int idx) {
    args[idx] = in_data[idx][tid + scalar_cal_offset];
  }

  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
                                                             int tid, int idx) {
    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
    args[idx] = in_data[idx][offset];
  }

  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
                                                     int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
        LoadVectorizedDataCommon(vector_args, tid, j);
      } else {
        LoadVectorizedDataByDivmod(args[j], tid, j);
      }
    }
  }

  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        LoadScalarizedDataCommon(args, tid, j);
      } else {
        LoadScalarizedDataByDivmod(args, tid, j);
      }
    }
  }

  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
                                                      int tid) {
    vec_out_data[tid] = vec_args_out;
  }

  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
    out_data[scalar_cal_offset + tid] = args_out;
  }
};

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
          ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
    BroadcastArgsWarpper broadcast_warpper, int tid) {
  InT args[ET];
  OutT args_out;
  broadcast_warpper.LoadScalarizedData(args, tid);

#pragma unroll(ET)
  for (int j = 1; j < ET; ++j) {
    args_out = broadcast_warpper.func(args);
  }
  broadcast_warpper.StoreScalarizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
          ElementwiseType ET, int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
    BroadcastArgsWarpper broadcast_warpper, int tid) {
  using OutVecType = CudaAlignedVector<OutT, VecSize>;
  OutVecType args_out;
  InT ins[ET];
  InT args[ET][VecSize];
  broadcast_warpper.LoadVectorizedData(args, tid);

#pragma unroll(VecSize)
  for (int i = 0; i < VecSize; ++i) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      ins[j] = args[j][i];
    }
    args_out.val[i] = broadcast_warpper.func(ins);
  }
  broadcast_warpper.StoreVectorizedData(args_out, tid);
}

template <typename InT, typename OutT, typename BroadcastArgsWarpper,
          ElementwiseType ET, int VecSize>
__global__ void ElementwiseBroadcastKernel(
    BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Vectorized calculation of the main body of the data, whose length is the
  // largest multiple of VecSize, e.g. the first 1024 of 1027 elements when
  // VecSize is 4.
  if (tid < main_tid) {
    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET, VecSize>(
        broadcast_warpper, tid);
  }
  // Scalarized calculation of the remaining data, whose length cannot fill a
  // whole VecSize, e.g. the last 3 of 1027 elements when VecSize is 4.
  if (tid < tail_tid) {
    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWarpper, ET>(
        broadcast_warpper, tid);
  }
}

template <typename InT, typename OutT, ElementwiseType ET, int VecSize,
          typename Functor>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  int numel = out->numel();
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  int main_tid = numel / VecSize;
  int tail_tid = numel % VecSize;
  int vec_len = main_tid * VecSize;
  auto stream = ctx.stream();
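  // For example: with numel = 1027 and VecSize = 4, main_tid = 256 vectorized
  // work items cover the first vec_len = 1024 elements and tail_tid = 3
  // scalar work items handle the remainder.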

  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
  const auto offset_calculator = StridesCalculation(
      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);

  switch (merge_dims.dim_size) {
    case 1: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 1>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 2: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 2>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 3: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 3>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 4: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 4>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 5: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 5>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 6: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 6>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 7: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 7>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    case 8: {
      auto broadcast_warpper =
          BroadcastArgsWarpper<InT, OutT, Functor, ET, VecSize, 8>(
              ins, out, vec_len, func, offset_calculator);
      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_warpper), ET,
                                 VecSize><<<blocks, threads, 0, stream>>>(
          broadcast_warpper, main_tid, tail_tid);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of the input tensor is expected to be less "
          "than %d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  static_assert(ET == (ElementwiseType)2,
                "Only binary calculation is supported.");
  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = GetVectorizedSizeImpl<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = GetVectorizedSizeImpl<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);
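  // vec_size is the largest vector width supported by both the output and
  // every input that shares the output's shape; broadcast inputs are gathered
  // element by element via divmod and therefore do not constrain the width.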

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 4>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 2>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentDimSize<InT, OutT, ET, 1>(ctx, ins, out,
                                                                 axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d!", vec_size));
      break;
    }
  }
}

template <ElementwiseType ET, typename InT, typename OutType, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= ins[0]->dims() == in->dims();
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, InT, OutType>(
        cuda_ctx, ins, outs, func);
  } else {
    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, InT,
                                         OutType>(cuda_ctx, ins, outs, axis,
                                                  func);
  }
}
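// A minimal usage sketch (hypothetical functor and tensor names, not part of
// this header); the functor receives a pointer to the loaded arguments:
//
//   template <typename T>
//   struct AddFunctor {
//     inline HOSTDEVICE T operator()(const T *args) const {
//       return args[0] + args[1];
//     }
//   };
//
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       dev_ctx, ins, &outs, axis, AddFunctor<T>());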

}  // namespace operators
}  // namespace paddle