/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"

namespace pten {

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
namespace kps = paddle::operators::kernel_primitives;
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };

/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
   for supporting multiple-output feature in elementwise system.*/
template <class T, int Num>
using ConditionalT =
    typename std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;

// Dispatches a functor call to the kps elementwise primitive matching its
// arity. This primary template only declares operator(); the partial
// specializations below provide the definitions for
// CallElementwiseAny == true and for Arity 1, 2 and 3.
//
// func:   functor applied to the loaded inputs.
// args:   one VecSize-long row of loaded values per input.
// result: destination for the computed values.
template <typename InT,
          typename OutT,
          int VecSize,
          typename Functor,
          int Arity,
          bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result);
};

// Specialization for CallElementwiseAny == true (any Arity): hands the whole
// `args` array to kps::ElementwiseAny together with the functor.
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
        result, args, func);
  }
};

// Unary specialization (Arity == 1): applies `func` to args[0] through
// kps::ElementwiseUnary.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], func);
  }
};

// Binary specialization (Arity == 2): applies `func` to args[0] and args[1]
// through kps::ElementwiseBinary.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], func);
  }
};

// Ternary specialization (Arity == 3): applies `func` to args[0], args[1]
// and args[2] through kps::ElementwiseTernary.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
};

// Stores the results of a multi-output elementwise functor.
//
// `src` holds VecSize packed results, each carrying NumOuts values. The
// values are regrouped into one VecSize-long row per output, and row i is
// written to `outs[i] + block_offset` through kps::WriteData.
//
// outs:         one destination pointer per output.
// src:          packed per-element results.
// block_offset: element offset added to every output pointer.
// num:          element count forwarded to kps::WriteData.
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCaller {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
      ConditionalT<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num) {
    // Transpose from "element-major" (src[i][j]) to "output-major"
    // (dst[j][i]) so that each output can be written as one contiguous row.
    OutT dst[NumOuts][VecSize];
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
#pragma unroll
      for (int j = 0; j < NumOuts; ++j) {
        dst[j][i] = (src[i])[j];
      }
    }
#pragma unroll
    for (int i = 0; i < NumOuts; ++i) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[i] + block_offset, dst[i], num);
    }
  }
};

// Single-output specialization (NumOuts == 1): the results are already plain
// OutT values, so `src` is written directly to `outs[0] + block_offset`
// through kps::WriteData without any regrouping.
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<_ptr_ OutT *, 1> outs,
      OutT src[VecSize],
      int block_offset,
      int num) {
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
        outs[0] + block_offset, src, num);
  }
};

// Processes one chunk of a same-shape elementwise operation: load every
// input at `data_offset`, apply `func`, store every output at `data_offset`.
//
// in:          one source pointer per input.
// outs:        one destination pointer per output.
// num:         element count forwarded to kps::ReadData / kps::WriteData.
// data_offset: element offset of this chunk in every input and output.
// func:        functor applied to the loaded values.
// IsBoundary is forwarded to the kps read/write primitives; the calling
// kernel sets it to true only for the trailing partial chunk.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl(
    const paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
    int num,
    int data_offset,
    Functor func) {
  InT args[Arity][VecSize];
  ConditionalT<OutT, NumOuts> result[VecSize];

#pragma unroll
  for (int i = 0; i < Arity; i++) {
    // Pre-fill with 1 before reading; presumably so that slots a boundary
    // read leaves untouched still hold a defined value.
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(
        args[i], in[i] + data_offset, num);
  }

  // Functors taking pointer arguments go through kps::ElementwiseAny; the
  // others are dispatched on Arity.
  constexpr bool kCallElementwiseAny =
      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
  ElementwisePrimitiveCaller<InT,
                             ConditionalT<OutT, NumOuts>,
                             VecSize,
                             Functor,
                             Arity,
                             kCallElementwiseAny>()(func, args, result);

  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
      outs, result, data_offset, num);
}

// Kernel for same-shape elementwise operations.
//
// Each block starts at BLOCK_ID_X * BLOCK_NUM_X * VecSize and advances by
// BLOCK_NUM_X * GRID_NUM_X * VecSize while its offset is below `main_offset`,
// handling full chunks of VecSize * BLOCK_NUM_X elements without boundary
// checks. Whatever remains between the final offset and `size` is handled
// once with IsBoundary == true.
//
// ins / outs:  input and output pointers.
// size:        total number of elements.
// main_offset: end of the region processed in full chunks; the launcher in
//              this file passes `size` rounded down to a multiple of
//              VecSize * block size.
// func:        functor applied to the loaded values.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
__global__ void VectorizedElementwiseKernel(
    paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
    int size,
    int main_offset,
    Functor func) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  for (; data_offset < main_offset; data_offset += stride) {
    VectorizedElementwiseKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
                                    NumOuts,
                                    VecSize,
                                    false>(
        ins, outs, VecSize * BLOCK_NUM_X, data_offset, func);
  }

  // Tail: only blocks whose offset is still inside [main_offset, size) have
  // work left.
  int num = size - data_offset;
  if (num > 0) {
    VectorizedElementwiseKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
                                    NumOuts,
                                    VecSize,
                                    true>(ins, outs, num, data_offset, func);
  }
}

// Returns the vector width shared by all tensors: the minimum of 4 and
// paddle::platform::GetVectorizedSize() of every input and output data
// pointer.
template <typename InT, typename OutT>
int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
                                const std::vector<DenseTensor *> &outs) {
  int vec_size = 4;
  for (const DenseTensor *in : ins) {
    const int in_vec = paddle::platform::GetVectorizedSize(in->data<InT>());
    vec_size = std::min<int>(vec_size, in_vec);
  }
  for (DenseTensor *out : outs) {
    const int out_vec = paddle::platform::GetVectorizedSize(out->data<OutT>());
    vec_size = std::min<int>(vec_size, out_vec);
  }
  return vec_size;
}

template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
215
void ElementwiseCudaKernel(const KPDevice &ctx,
216 217 218 219
                           const std::vector<const DenseTensor *> &ins,
                           std::vector<DenseTensor *> *outs,
                           Functor func) {
  auto numel = ins[0]->numel();
220 221
  paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
  paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
222 223 224 225 226 227 228 229

  for (int i = 0; i < Arity; ++i) {
    ins_data[i] = ins[i]->data<InT>();
  }
  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
  }
#ifdef PADDLE_WITH_XPU2
230 231 232
  int block_size = 64;
  int grid_size = 8;
  auto stream = ctx.x_context()->xpu_stream;
233 234 235 236 237 238 239 240 241
  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
  VectorizedElementwiseKernel<InT,
                              OutT,
                              Functor,
                              Arity,
                              NumOuts,
                              VecSize><<<grid_size, block_size, 0, stream>>>(
      ins_data, outs_data, numel, main_offset, func);
#else
242 243 244
  int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize);
  int grid_size =
      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
245
  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
246
  auto stream = ctx.stream();
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
  VectorizedElementwiseKernel<InT,
                              OutT,
                              Functor,
                              Arity,
                              NumOuts,
                              VecSize><<<grid_size, block_size, 0, stream>>>(
      ins_data, outs_data, numel, main_offset, func);
#endif
}

template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void LaunchSameDimsElementwiseCudaKernel(
263
    const KPDevice &ctx,
264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func) {
  using Traits = paddle::platform::FunctionTraits<Functor>;
  const int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
  PADDLE_ENFORCE_EQ(ins.size(),
                    kArity,
                    paddle::platform::errors::InvalidArgument(
                        "The number of inputs is expected to be equal to the "
                        "arity of functor. But recieved: the number of inputs "
                        "is %d, the arity of functor is %d.",
                        ins.size(),
                        kArity));
  PADDLE_ENFORCE_EQ(outs->size(),
                    NumOuts,
                    paddle::platform::errors::InvalidArgument(
                        "Number of outputs shall equal to number of functions, "
                        "but number of outputs is %d, of functions is %d.",
                        outs->size(),
                        NumOuts));

  if (NumOuts > 1) {
    for (int i = 1; i < NumOuts; ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          paddle::platform::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, "
              "but %dth output tensor`s shape is not.",
              i));
    }
  }

  // calculate the max vec_size for all ins and outs
  int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
  switch (vec_size) {
    case 4:
      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
          ctx, ins, outs, func);
      break;
    case 2:
      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
          ctx, ins, outs, func);
      break;
    case 1:
      ElementwiseCudaKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
          ctx, ins, outs, func);
      break;
    default: {
      PADDLE_THROW(paddle::platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(
      bool &, std::vector<DimVector> &, DimVector &, int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // To compensate the lackage of input_tensors` dimension with input variable
  // 'axis'
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(paddle::platform::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "with the %d-th dimension of output tensor %d or 1, but "
                "recieved %d.",
                in_idx + 1,
                axis + 1,
                out_dims[axis],
                in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(paddle::platform::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "with the %d-th dimension of output tensor %d or 1, but "
                "recieved %d.",
                in_idx + 1,
                in_idx + 1,
                out_dims[in_idx],
                in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx,
                                          vec->begin() + m_idx,
                                          1,
                                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
                               const paddle::framework::DDim &dims,
                               int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = paddle::framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = paddle::framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out,
                                    int i,
                                    int num) {
      for (int j = 1; j < num; ++j) {
433
        equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
434 435 436 437 438 439 440 441 442 443
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out,
                                        int i,
                                        int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
444
          equal &= in_dims[j][i] == out[i];
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474
        }
      }
    };
    // To Merge the dimensions of input_tensors while the consequtive
    // equal-dimensions appears.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(
        in_dims[0].begin(), in_dims[0].end(), 1, std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(
          in_dims[j].begin(), in_dims[j].end(), 1, std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // To Merge the dimension of input_tensors while the consequtive
    // 1-value-dimensions appears.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};

// Loads VecSize elements of one input into `dst`.
//
// dst:            destination buffer.
// src:            base pointer of the input.
// block_offset:   element offset of the chunk being processed.
// config:         broadcast configuration of this input.
// numel:          whole num of output.
// num:            how many data will be deal with in this time.
// need_broadcast: non-zero selects kps::ReadDataBc (broadcast read driven by
//                 `config`); zero selects a plain kps::ReadData from
//                 `src + block_offset`.
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst,
    const _ptr_ T *src,
    uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank> &config,
    int numel,
    int num,
    int need_broadcast) {
  if (need_broadcast) {
    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
        dst, src, block_offset, config, numel);
  } else {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
  }
}

template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
495
          int NumOuts,
496 497 498
          int VecSize,
          int Rank,
          bool IsBoundary = false>
499
__device__ void ElementwiseBroadcastKernelImpl(
500 501 502
    const paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> &ins,
    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
    const paddle::framework::Array<int, Arity> &use_broadcast,
503 504 505 506
    uint32_t numel,
    const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
        &configs,
    int num,
507
    int block_offset,
508 509
    Functor func) {
  InT args[Arity][VecSize];
510
  ConditionalT<OutT, NumOuts> result[VecSize];
511 512 513 514 515 516 517 518 519 520 521 522

#pragma unroll
  for (int i = 0; i < Arity; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
                                             ins[i],
                                             block_offset,
                                             configs[i],
                                             numel,
                                             num,
                                             use_broadcast[i]);
  }
523
  constexpr bool kCallElementwiseAny =
524 525
      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
  ElementwisePrimitiveCaller<InT,
526
                             ConditionalT<OutT, NumOuts>,
527 528 529 530
                             VecSize,
                             Functor,
                             Arity,
                             kCallElementwiseAny>()(func, args, result);
531 532 533

  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
      outs, result, block_offset, num);
534 535 536 537 538 539
}

template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
540
          int NumOuts,
541 542
          int VecSize,
          int Rank>
543
__global__ void ElementwiseBroadcastKernel(
544 545 546
    paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
    paddle::framework::Array<_ptr_ OutT *, NumOuts> outs,
    paddle::framework::Array<int, Arity> use_broadcast,
547 548 549
    uint32_t numel,
    paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
        configs,
550
    int main_offset,
551 552
    int tail_tid,
    Functor func) {
553 554
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
555

556 557 558 559 560 561
#ifdef PADDLE_WITH_XPU2
  for (; block_offset < main_offset; block_offset += stride) {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
562
                                   NumOuts,
563 564 565
                                   VecSize,
                                   Rank,
                                   false>(ins,
566
                                          outs,
567 568 569 570 571 572 573
                                          use_broadcast,
                                          numel,
                                          configs,
                                          BLOCK_NUM_X * VecSize,
                                          block_offset,
                                          func);
  }
574 575
  int num = numel - block_offset;
  if (num > 0) {
576 577 578 579
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
580
                                   NumOuts,
581 582 583
                                   VecSize,
                                   Rank,
                                   true>(
584
        ins, outs, use_broadcast, numel, configs, num, block_offset, func);
585
  }
586 587 588 589 590 591
#else
  if (block_offset < main_offset) {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
592
                                   NumOuts,
593 594 595
                                   VecSize,
                                   Rank,
                                   false>(ins,
596
                                          outs,
597 598 599 600 601 602 603 604 605 606 607
                                          use_broadcast,
                                          numel,
                                          configs,
                                          BLOCK_NUM_X * VecSize,
                                          block_offset,
                                          func);
  } else {
    ElementwiseBroadcastKernelImpl<InT,
                                   OutT,
                                   Functor,
                                   Arity,
608
                                   NumOuts,
609 610 611
                                   VecSize,
                                   Rank,
                                   true>(
612
        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
613 614
  }
#endif
615 616 617 618 619 620
}

template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
621
          int NumOuts,
622 623
          int VecSize,
          int Rank>
624
void LaunchKernel(const KPDevice &ctx,
625
                  const std::vector<const DenseTensor *> &ins,
626
                  std::vector<DenseTensor *> *outs,
627 628
                  Functor func,
                  DimensionsTransform merge_dims) {
629
  int numel = (*outs)[0]->numel();
630
  paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
631 632 633
  paddle::framework::Array<int, Arity> use_broadcast;
  paddle::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
  paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
634 635 636 637

  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
  }
638 639 640

  for (int i = 0; i < Arity; i++) {
    use_broadcast[i] = (ins[i]->numel() != numel);
641
    ins_data[i] = (_ptr_ InT *)(ins[i]->data<InT>());
642 643 644 645 646 647 648 649
    if (use_broadcast[i]) {
      // get the broadcast config,
      // if data shape is[m, n], then you should set data_dim = {n, m}
      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
      configs[i] = kps::details::BroadcastConfig<Rank>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
  }
650

651
#ifdef PADDLE_WITH_XPU2
652 653 654 655 656
  const int threads = 64;
  const int blocks = 8;
  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
  int tail_tid = numel % (VecSize * threads);
  auto stream = ctx.x_context()->xpu_stream;
657 658 659 660
  ElementwiseBroadcastKernel<InT,
                             OutT,
                             Functor,
                             Arity,
661
                             NumOuts,
662 663
                             VecSize,
                             Rank><<<blocks, threads, stream>>>(ins_data,
664
                                                                outs_data,
665 666 667 668 669 670 671
                                                                use_broadcast,
                                                                numel,
                                                                configs,
                                                                main_offset,
                                                                tail_tid,
                                                                func);
#else
672 673 674 675 676
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
  int tail_tid = numel % (VecSize * threads);
  auto stream = ctx.stream();
677 678 679 680
  ElementwiseBroadcastKernel<InT,
                             OutT,
                             Functor,
                             Arity,
681
                             NumOuts,
682 683 684
                             VecSize,
                             Rank><<<blocks, threads, 0, stream>>>(
      ins_data,
685
      outs_data,
686 687 688 689 690 691 692
      use_broadcast,
      numel,
      configs,
      main_offset,
      tail_tid,
      func);
#endif
693 694
}

695 696 697 698 699 700
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
701
void LaunchBroadcastKernelForDifferentVecSize(
702
    const KPDevice &ctx,
703
    const std::vector<const DenseTensor *> &ins,
704
    std::vector<DenseTensor *> *outs,
705 706
    int axis,
    Functor func) {
707
  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
708

709 710 711 712
#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                            \
  case rank: {                                                       \
    LaunchKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
        ctx, ins, outs, func, merge_dims);                           \
713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734
  } break;

  switch (merge_dims.dim_size) {
    CALL_BROADCAST_FOR_DIM_SIZE(1);
    CALL_BROADCAST_FOR_DIM_SIZE(2);
    CALL_BROADCAST_FOR_DIM_SIZE(3);
    CALL_BROADCAST_FOR_DIM_SIZE(4);
    CALL_BROADCAST_FOR_DIM_SIZE(5);
    CALL_BROADCAST_FOR_DIM_SIZE(6);
    CALL_BROADCAST_FOR_DIM_SIZE(7);
    CALL_BROADCAST_FOR_DIM_SIZE(8);
    default: {
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but recieved %d.\n",
          merge_dims.dim_size,
          paddle::framework::DDim::kMaxRank));
    }
  }
#undef CALL_BROADCAST_FOR_DIM_SIZE
}

735 736 737 738 739
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
740
void LaunchBroadcastElementwiseCudaKernel(
741
    const KPDevice &ctx,
742 743 744 745 746 747 748 749 750 751 752 753 754 755 756
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
  using Traits = paddle::platform::FunctionTraits<Functor>;
  const int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
  PADDLE_ENFORCE_EQ(ins.size(),
                    kArity,
                    paddle::platform::errors::InvalidArgument(
                        "The number of inputs is expected to be equal to the "
                        "arity of functor. But recieved: the number of inputs "
                        "is %d, the arity of functor is %d.",
                        ins.size(),
                        kArity));
757
  PADDLE_ENFORCE_LE(kArity,
L
limingshu 已提交
758
                    3,
759
                    paddle::platform::errors::InvalidArgument(
760 761
                        "Currently only broadcast of ternary is supported "
                        "and verified, but received %d.",
762
                        kArity));
763 764 765 766 767 768 769
  PADDLE_ENFORCE_EQ(outs->size(),
                    NumOuts,
                    paddle::platform::errors::InvalidArgument(
                        "Number of outputs shall equal to number of functions, "
                        "but number of outputs is %d, of functions is %d.",
                        outs->size(),
                        NumOuts));
770
  int in_vec_size = 4;
771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789
  int out_vec_size = 4;
  if (NumOuts > 1) {
    for (int i = 0; i < NumOuts; ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          paddle::platform::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%dth output tensor`s shape is not.",
              i));
      out_vec_size = std::min(
          paddle::platform::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()),
          out_vec_size);
    }
  } else {
    out_vec_size =
        paddle::platform::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
  }

790 791
  for (auto *in : ins) {
    auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
792 793 794
    in_vec_size = in->dims() == (*outs)[0]->dims()
                      ? std::min(temp_size, in_vec_size)
                      : in_vec_size;
795 796 797 798 799
  }
  int vec_size = std::min(out_vec_size, in_vec_size);

  switch (vec_size) {
    case 4: {
800 801 802 803 804 805
      LaunchBroadcastKernelForDifferentVecSize<InT,
                                               OutT,
                                               Functor,
                                               kArity,
                                               NumOuts,
                                               4>(ctx, ins, outs, axis, func);
806 807 808
      break;
    }
    case 2: {
809 810 811 812 813 814
      LaunchBroadcastKernelForDifferentVecSize<InT,
                                               OutT,
                                               Functor,
                                               kArity,
                                               NumOuts,
                                               2>(ctx, ins, outs, axis, func);
815 816 817
      break;
    }
    case 1: {
818 819 820 821 822 823
      LaunchBroadcastKernelForDifferentVecSize<InT,
                                               OutT,
                                               Functor,
                                               kArity,
                                               NumOuts,
                                               1>(ctx, ins, outs, axis, func);
824 825 826 827 828 829 830 831 832 833
      break;
    }
    default: {
      PADDLE_THROW(paddle::platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

834 835 836 837 838
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
839 840 841 842 843
void LaunchElementwiseCudaKernel(const KPDevice &ctx,
                                 const std::vector<const DenseTensor *> &ins,
                                 std::vector<DenseTensor *> *outs,
                                 int axis,
                                 Functor func) {
844 845 846 847 848 849 850 851
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag &= ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }
  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
852
        ctx, ins, outs, func);
853 854 855 856 857 858
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT, Functor, NumOuts>(
859
        ctx, ins, outs, axis, func);
860 861 862
  }
}

}  // namespace pten