elementwise_op_broadcast.cu.h 15.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast_impl.cu.h"

namespace paddle {
namespace operators {

// Normalizes the shapes of a set of input tensors against the output shape so
// that the broadcast kernel can index them with as few dimensions as possible.
//
// After construction:
//   * every `in_dims[j]` and `out_dims` has exactly `dim_size` entries;
//   * all shapes are stored reversed (index 0 holds what was the LAST
//     dimension of the original shape);
//   * consecutive dimensions that can be treated as a single dimension have
//     been merged (see the constructor for the two merge rules).
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // Pads every input shape up to `dim_size` ranks with size-1 dimensions and
  // then reverses all shapes (inputs and output).
  //
  // A lower-rank input is aligned so that its first dimension lines up with
  // output dimension `axis`; a negative `axis` aligns the input with the
  // trailing output dimensions. Each input starts from the caller's `axis`
  // independently. Every input dimension must equal the matching output
  // dimension or be 1; otherwise InvalidArgument is thrown.
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      const int64_t in_rank = static_cast<int64_t>(in_dim.size());
      if (in_rank < dim_size) {
        // Per-input start position (the axis must not carry over between
        // inputs).
        int64_t out_idx =
            axis < 0 ? dim_size - in_rank : static_cast<int64_t>(axis);
        if (out_idx + in_rank > dim_size) {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "The input tensor of rank %d cannot be aligned to the output "
              "tensor of rank %d starting at axis %d.\n",
              in_rank, dim_size, out_idx));
        }
        DimVector tmp_dim(dim_size, 1);
        for (int64_t in_idx = 0; in_idx < in_rank; ++in_idx, ++out_idx) {
          if (in_dim[in_idx] != out_dims[out_idx] && in_dim[in_idx] != 1) {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "with "
                "the %dth dimension of output tensor %d or 1, but received "
                "%d.\n",
                in_idx + 1, out_idx + 1, out_dims[out_idx], in_dim[in_idx]));
          }
          tmp_dim[out_idx] = in_dim[in_idx];
        }
        in_dim.swap(tmp_dim);
      } else {
        for (int64_t in_idx = 0; in_idx < dim_size; ++in_idx) {
          if (in_dim[in_idx] != out_dims[in_idx] && in_dim[in_idx] != 1) {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %dth dimension of input tensor is expected to be equal "
                "with "
                "the %dth dimension of output tensor %d or 1, but received "
                "%d.\n",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        }
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  // Scans the dimensions left to right. For every maximal run [low_idx, i) of
  // positions where `merge_func` leaves `equal == true`, it replaces that run
  // in every input shape and in `out_dims` with a single dimension. The new
  // dimension's extent is the product of the run. `dim_size` shrinks
  // accordingly.
  template <typename MergeFn>
  __inline__ void DimensionsReorganise(MergeFn merge_func, int N) {
    // Stores the product of vec[l_idx, m_idx) at m_idx - 1, then erases
    // [l_idx, m_idx - 1) so the merged extent lands at l_idx. The int64_t
    // init keeps the accumulation in 64 bits (an `int` init would truncate).
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx,
                          static_cast<int64_t>(1), std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        // cnt dimensions became one: drop cnt - 1 and resume right after the
        // merged dimension.
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  // ins:  input tensors (at least one).
  // dims: output shape.
  // axis: output dimension that the first dimension of a lower-rank input is
  //       aligned to; negative means "align with the trailing dimensions".
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    // `equal` stays true only when EVERY input has the same extent at
    // dimension i. `equal` is true on entry.
    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num && equal; ++j) {
        equal = in_dims[0][i] == in_dims[j][i];
      }
    };
    // `equal` is true when the reference input (in_dims[0]) has extent 1 at
    // dimension i and EVERY other input matches the output there.
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      for (int j = 1; j < num && equal; ++j) {
        equal = in_dims[j][i] == out[i];
      }
    };
    // Merge runs of consecutive dimensions on which all inputs agree.
    MergeFunctor merge_ptr = merge_sequential_dims;
    DimensionsReorganise<MergeFunctor>(merge_ptr, N);

    // Pick the input with the fewest elements (the last one on ties) as the
    // reference for the next merge pass. Products are computed in 64 bits.
    int min_idx = 0;
    int64_t min_val =
        std::accumulate(in_dims[0].begin(), in_dims[0].end(),
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      const int64_t temp =
          std::accumulate(in_dims[j].begin(), in_dims[j].end(),
                          static_cast<int64_t>(1), std::multiplies<int64_t>());
      if (temp <= min_val) {
        min_val = temp;
        min_idx = j;
      }
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // Merge runs of consecutive dimensions on which the reference input is 1
    // while every other input matches the output; then restore input order.
    merge_ptr = merge_sequential_one_dims;
    DimensionsReorganise<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};

// Per-input element strides and per-dimension divisors used to map a flat
// output index to a flat offset inside each input.
// NOTE(review): `in_dims`/`out_dims` are assumed to be in the merged,
// reversed layout produced by DimensionsTransform (as the launcher in this
// file passes them) — confirm for any other caller.
struct CalculateInputStrides {
  // strides[j][i]: element stride of input j along dimension i, or 0 when
  // input j has extent 1 there (that coordinate then contributes nothing).
  std::vector<std::vector<uint32_t>> strides;
  // divmoders[i]: FastDivMod built from out_dims[i].
  std::vector<FastDivMod> divmoders;

 private:
  // Stride of dimension i is the product of extents 0..i-1 of that input,
  // except that size-1 dimensions get stride 0. The running product is kept
  // in 64 bits and only its low 32 bits are stored.
  __inline__ void CalculateStrides(
      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
    for (int input = 0; input < N; ++input) {
      const std::vector<int64_t> &shape = in_dims[input];
      std::vector<uint32_t> &row = strides[input];
      uint64_t span = 1;
      for (int dim = 0; dim < dim_size; ++dim) {
        row[dim] = (shape[dim] == 1) ? 0u : static_cast<uint32_t>(span);
        span *= static_cast<uint64_t>(shape[dim]);
      }
    }
  }

 public:
  explicit CalculateInputStrides(
      const int64_t &dim_size, const std::vector<std::vector<int64_t>> &in_dims,
      const std::vector<int64_t> &out_dims) {
    const auto N = in_dims.size();
    strides.resize(N, std::vector<uint32_t>(dim_size, 1));
    divmoders.resize(dim_size);
    for (int64_t i = 0; i < dim_size; ++i) {
      divmoders[i] = FastDivMod(out_dims[i]);
    }
    CalculateStrides(N, dim_size, in_dims);
  }
};

// Kernel-argument bundle for the broadcast kernel: raw input/output pointers
// plus the per-input strides and per-dimension divisors needed to map a flat
// output index to each input's offset. It is passed to the kernel by value.
//
// T:       element type.
// ET:      number of inputs (used as an array extent).
// VecSize: elements handled per thread on the vectorized path.
// kDims:   number of (merged) dimensions; must fit in the kMaxRank-sized
//          strides buffer.
template <typename T, ElementwiseType ET, int VecSize, int kDims>
struct BroadcastArgsWarpper {
  static_assert(kDims >= 1 && kDims <= framework::DDim::kMaxRank,
                "kDims must be in [1, DDim::kMaxRank] to fit the strides "
                "buffer.");
  using DimsVec = CudaAlignedVector<T, VecSize>;

  T *out_data;
  const T *__restrict__ in_data[ET];
  // First kDims entries of each row are copied from CalculateInputStrides.
  uint32_t strides[ET][framework::DDim::kMaxRank];
  // True when input j has exactly the output's dims, so it is read at the
  // same flat index as the output (no divmod decomposition needed).
  bool no_broadcast[ET];
  FastDivMod divmoders[kDims];
  // Added to `tid` on the scalar paths (the launcher passes the length of
  // the vectorized prefix).
  uint32_t scalar_offset;

  // Host-only: reads std::vector / Tensor objects and uses host memcpy, so it
  // cannot be a __host__ __device__ function.
  BroadcastArgsWarpper(
      const std::vector<const framework::Tensor *> &ins,
      const CalculateInputStrides &offset_calculator, framework::Tensor *out,
      int scalar_offset)
      : scalar_offset(scalar_offset) {
    for (int j = 0; j < ET; ++j) {
      in_data[j] = ins[j]->data<T>();
      no_broadcast[j] = ins[j]->dims() == out->dims();
      memcpy(strides[j], offset_calculator.strides[j].data(),
             kDims * sizeof(uint32_t));
    }
    out_data = out->data<T>();
    memcpy(divmoders, offset_calculator.divmoders.data(),
           kDims * sizeof(FastDivMod));
  }

  // Decomposes the flat output index `idx` dimension by dimension with
  // `divmoders` and returns the matching element offset inside input
  // `in_idx`. NOTE(review): assumes Divmod returns {quotient, remainder} in
  // val[0]/val[1] — confirm against FastDivMod.
  __device__ __forceinline__ uint32_t GetDivmodOffset(int idx, int in_idx) {
    uint32_t offset = 0;

#pragma unroll(kDims)
    for (int i = 0; i < kDims; ++i) {
      auto fast_divmoder = divmoders[i].Divmod(idx);
      idx = fast_divmoder.val[0];
      offset += fast_divmoder.val[1] * strides[in_idx][i];
    }
    return offset;
  }

  // Vector load of input `idx` at vector position `tid` (no broadcasting).
  __device__ __forceinline__ void CommonVector(DimsVec args[], int tid,
                                               int idx) {
    const DimsVec *__restrict__ vec_data =
        reinterpret_cast<const DimsVec *__restrict__>(in_data[idx]);
    args[idx] = vec_data[tid];
  }

  // Element-wise gather of VecSize values of input `idx` for output indices
  // tid * VecSize + [0, VecSize).
  __device__ __forceinline__ void DivmodVector(DimsVec args[], int tid,
                                               int idx) {
    int index = tid * VecSize;

    for (int i = 0; i < VecSize; ++i) {
      uint32_t offset = GetDivmodOffset(index + i, idx);
      args[idx].val[i] = in_data[idx][offset];
    }
  }

  // Scalar load of input `idx` at flat index scalar_offset + tid.
  __device__ __forceinline__ void CommonScalar(T args[], int tid, int idx) {
    args[idx] = in_data[idx][tid + scalar_offset];
  }

  // Scalar gather of input `idx` for output index scalar_offset + tid.
  __device__ __forceinline__ void DivmodScalar(T args[], int tid, int idx) {
    auto offset = GetDivmodOffset(tid + scalar_offset, idx);
    args[idx] = in_data[idx][offset];
  }

  // Fills args[0..ET) with the inputs for vector position `tid`.
  __device__ __forceinline__ void LoadVector(DimsVec args[], int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        CommonVector(args, tid, j);
      } else {
        DivmodVector(args, tid, j);
      }
    }
  }

  // Fills args[0..ET) with the inputs for output index scalar_offset + tid.
  __device__ __forceinline__ void LoadScalar(T args[], int tid) {
#pragma unroll(ET)
    for (int j = 0; j < ET; ++j) {
      if (no_broadcast[j]) {
        CommonScalar(args, tid, j);
      } else {
        DivmodScalar(args, tid, j);
      }
    }
  }

  // Writes args[0] to output vector position `tid`.
  __device__ __forceinline__ void StoreVector(DimsVec args[], int tid) {
    DimsVec *vec_out = reinterpret_cast<DimsVec *>(out_data);
    vec_out[tid] = args[0];
  }

  // Writes args[0] to output index scalar_offset + tid.
  __device__ __forceinline__ void StoreScalar(T args[], int tid) {
    out_data[scalar_offset + tid] = args[0];
  }
};

// Scalar path for one thread. It loads one element from each of the ET
// inputs for output index `scalar_offset + tid` and sums them with `+` in
// input order. The sum is written back through the wrapper's StoreScalar.
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET>
__device__ inline void ScalarizedBroadcastKernelImpl(
    BroadcastArgsWarpper data_transfer, int tid) {
  T operands[ET];
  data_transfer.LoadScalar(operands, tid);

  T acc = operands[0];
#pragma unroll(ET)
  for (int k = 1; k < ET; ++k) {
    acc += operands[k];
  }
  operands[0] = acc;

  data_transfer.StoreScalar(operands, tid);
}

// Vectorized path for one thread. It loads VecSize elements from each of the
// ET inputs at vector position `tid` and sums them lane by lane with `+`.
// The addends of each lane are added in input order. The result is stored
// through the wrapper's StoreVector.
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
          int VecSize>
__device__ inline void VectorizedBroadcastKernelImpl(
    BroadcastArgsWarpper data_transfer, int tid) {
  using VecT = CudaAlignedVector<T, VecSize>;
  VecT lanes[ET];
  data_transfer.LoadVector(lanes, tid);

#pragma unroll(VecSize)
  for (int lane = 0; lane < VecSize; ++lane) {
#pragma unroll(ET)
    for (int src = 1; src < ET; ++src) {
      lanes[0].val[lane] += lanes[src].val[lane];
    }
  }

  data_transfer.StoreVector(lanes, tid);
}

// Broadcast kernel for a 1-D launch. The global thread id `tid`:
//   * when tid < main_tid, processes vector position `tid` (VecSize
//     elements) on the vectorized path;
//   * when tid < tail_tid, ALSO processes one leftover element on the scalar
//     path (the two conditions are independent, not if/else).
// Threads past both bounds do nothing.
template <typename T, typename BroadcastArgsWarpper, ElementwiseType ET,
          int VecSize>
__global__ void ElementwiseBroadcastKernel(BroadcastArgsWarpper data_transfer,
                                           int main_tid, int tail_tid) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const bool owns_vector = tid < main_tid;
  const bool owns_tail = tid < tail_tid;

  if (owns_vector) {
    VectorizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET, VecSize>(
        data_transfer, tid);
  }
  if (owns_tail) {
    ScalarizedBroadcastKernelImpl<T, BroadcastArgsWarpper, ET>(data_transfer,
                                                               tid);
  }
}

// Builds the broadcast argument wrapper for a merged rank of `kDims` and
// launches ElementwiseBroadcastKernel on `ctx`'s stream.
template <typename T, ElementwiseType ET, int VecSize, int kDims>
void LaunchBroadcastKernelWithMergedRank(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    const CalculateInputStrides &offset_calculator, int blocks, int threads,
    int main_tid, int tail_tid, int vec_len) {
  auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, kDims>(
      ins, offset_calculator, out, vec_len);
  ElementwiseBroadcastKernel<T, decltype(data_transfer), ET,
                             VecSize><<<blocks, threads, 0, ctx.stream()>>>(
      data_transfer, main_tid, tail_tid);
}

// Launches the broadcast kernel writing `out` from `ins`.
//
// The first `main_tid * VecSize` output elements are handled VecSize at a time
// by `main_tid` threads; the remaining `numel % VecSize` elements go one per
// thread to the scalar path. Input and output shapes are first merged by
// DimensionsTransform (`axis` is forwarded to it). The merged rank selects the
// kDims instantiation.
//
// An empty output launches nothing. This avoids a zero-block grid, which is an
// invalid launch configuration. Throws InvalidArgument when the merged rank is
// outside the supported range [1, 8].
template <typename T, ElementwiseType ET, int VecSize = 1>
void LaunchBroadcastKernelForDifferentDimSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis) {
  const int64_t total = out->numel();
  if (total == 0) {
    return;
  }
  const int numel = static_cast<int>(total);
  const int threads = 256;
  const int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  const int main_tid = numel / VecSize;
  const int tail_tid = numel % VecSize;
  const int vec_len = main_tid * VecSize;

  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
  const auto offset_calculator = CalculateInputStrides(
      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);

  switch (merge_dims.dim_size) {
    case 1:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 1>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 2:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 2>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 3:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 3>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 4:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 4>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 5:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 5>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 6:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 6>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 7:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 7>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    case 8:
      LaunchBroadcastKernelWithMergedRank<T, ET, VecSize, 8>(
          ctx, ins, out, offset_calculator, blocks, threads, main_tid,
          tail_tid, vec_len);
      break;
    default: {
      // Arguments ordered to match the message: limit first, actual second.
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
}

// Entry point: picks a vector width (4, 2 or 1) and dispatches to
// LaunchBroadcastKernelForDifferentDimSize.
//
// Only inputs whose dims equal the output's use vector loads, so only those
// inputs' pointer alignment (plus the output's) limits the width.
// NOTE(review): `func` is not forwarded to the launch below — confirm whether
// that is intended.
template <ElementwiseType ET, typename T, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    Functor func, int axis) {
  int vec_size = 4;
  for (const framework::Tensor *tensor : ins) {
    const int candidate = GetVectorizedSizeImpl<T>(tensor->data<T>());
    if (tensor->dims() == out->dims()) {
      vec_size = std::min(vec_size, candidate);
    }
  }
  vec_size = std::min(vec_size, GetVectorizedSizeImpl<T>(out->data<T>()));

  if (vec_size == 4) {
    LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis);
  } else if (vec_size == 2) {
    LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis);
  } else {
    LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis);
  }
}

}  // namespace operators
}  // namespace paddle