/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <sstream>
#include "paddle/phi/kernels/funcs/elementwise_base.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"

namespace kps = phi::kps;

#endif

namespace phi {
namespace funcs {

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)

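// How the kernel reads each input:
//   kElementwise: every input has as many elements as the output, so all
//                 loads are plain vectorized reads.
//   kBroadcast:   most inputs need broadcast index remapping before loading.
//   kMixed:       the load scheme is chosen per input.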
enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };

template <int Index>
struct UseBroadcast {
  template <typename ArgsT, typename Array1, typename Array2>
  static HOSTDEVICE void Apply(
      const std::vector<const DenseTensor *> &ins_tensor,
      const ArgsT &args,
      int64_t numel,
      Array1 *ins_data,
      Array2 *use_broadcast,
      int *broadcast_num,
      bool *all_elementwise) {
    (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data());
    bool is_same_dim = ins_tensor[Index]->numel() == numel;
    if (is_same_dim) {
      (*use_broadcast)[Index] = false;
    } else {
      (*use_broadcast)[Index] = true;
      (*broadcast_num)++;
    }
    *all_elementwise &= is_same_dim;
  }
};

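// One-pass host-side scan over inputs and outputs: records raw data
// pointers, derives the common vectorization width from all addresses, and
// counts the inputs that need broadcasting so a load type can be chosen at
// launch time.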
template <typename OutT, int Arity, typename Functor>
struct LoaderTypeClassifier {
 public:
  int64_t numel{0};
  int vec_size{4};
  int broadcast_num{0};
  bool all_elementwise{true};
  phi::Array<bool, Arity> use_broadcast;
  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;

  LoaderTypeClassifier() {}
  LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
                       std::vector<DenseTensor *> *outs) {
    using Traits = phi::funcs::FunctionTraits<Functor>;
    using ArgsT = typename Traits::ArgsTuple;
    ArgsT arg;
    uint64_t out_addr = reinterpret_cast<uint64_t>((*outs)[0]->data<OutT>());

    UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);

    for (auto i = 1; i < outs->size(); ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          phi::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
      out_addr =
          (out_addr | reinterpret_cast<uint64_t>((*outs)[i]->data<OutT>()));
    }

    vec_size = std::min(
        vec_size,
        phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr)));
    numel = (*outs)[0]->numel();
    UnrollerWithoutVecSize<UseBroadcast, Arity>::step(ins,
                                                      arg,
                                                      numel,
                                                      &ins_data,
                                                      &use_broadcast,
                                                      &broadcast_num,
                                                      &all_elementwise);
  }
};

// Common broadcast/elementwise Loader.
template <int Index, int VecSize, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel,
                                               int read_lens) {
    using Type = std::tuple_element_t<Index, ArgsT>;
#ifdef PADDLE_WITH_XPU_KP
    kps::Init<Type, ArgsT, Index, VecSize>(
        args, static_cast<Type>(1.0f), read_lens);
    if (use_broadcast[Index]) {
      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
          block_offset,
          configs[Index],
          numel,
          read_lens);
    } else {
      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
          num,
          read_lens);
    }
#else
    kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
    if (use_broadcast[Index]) {
      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
          block_offset,
          configs[Index],
          numel,
          VecSize);
    }
    // NOTE: Using a single if...else... on the condition
    // `use_broadcast[Index]` here triggers compile errors with clang12 when
    // building for ROCm. Once the compiler is upgraded, if...else... may be
    // used.
    if (!use_broadcast[Index]) {
      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
          num,
          VecSize);
    }
#endif
  }
};

/* BroadcastDataLoaders Partial specialization */
#ifndef PADDLE_WITH_XPU_KP
// Scalar elementwise Loader with consideration of IsBoundary.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel,
                                               int read_lens) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = static_cast<Type>(1);
      int index = thread_offset + idx;
      if (index < numel) {
        std::get<Index>(args[idx]) =
            reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
      }
    }
  }
};

// Vectorized elementwise Loader without consideration of IsBoundary.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel,
                                               int read_lens) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    using VecType = phi::kps::details::VectorType<Type, VecSize>;
    VecType vec_temp;

    int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
    const VecType *__restrict__ vec_input =
        reinterpret_cast<const VecType *__restrict__>(ins[Index]);
    vec_temp = vec_input[thread_offset];
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = vec_temp.val[idx];
    }
  }
};

template <int Index, int VecSize>
struct BroadcastDataInit {
  template <typename ArgsT>
  static __device__ __forceinline__ void Apply(ArgsT *args) {
    using Type = std::tuple_element_t<Index, ArgsT>;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      std::get<Index>(args[k]) = static_cast<Type>(1);
    }
  }
};

template <int Index, int VecSize>
struct BroadcastDataSetter {
  template <typename Array, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array &ins,
                                               ArgsT *args,
                                               uint32_t index_bc[][VecSize]) {
    using Type = std::tuple_element_t<Index, ArgsT>;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      std::get<Index>(args[k]) =
          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[Index][k]];
    }
  }
};

#endif

// static broadcast unroller
template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
          typename Func,
          bool IsBoundary,
          int LoadType,
          int VecSize,
          int End,
          int Begin = 0>
struct BcUnroller {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&...args) {
    Func<Begin, VecSize, IsBoundary, LoadType>::Apply(
        std::forward<Args>(args)...);
    BcUnroller<Func, IsBoundary, LoadType, VecSize, End, Begin + 1>::step(
        args...);
  }
};

template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
          typename Func,
          bool IsBoundary,
          int LoadType,
          int VecSize,
          int End>
struct BcUnroller<Func, IsBoundary, LoadType, VecSize, End, End> {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&...args) {}
};

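// Computes one tile of the output: each thread loads VecSize elements per
// input (broadcast-remapped or contiguous), applies `func` on the argument
// tuples, and writes the results of all NumOuts outputs back.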
template <typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsBoundary,
          int LoadType>
__device__ void VectorizedBroadcastKernelImpl(
    const phi::Array<const _ptr_ char *__restrict__, Arity> &ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    const phi::Array<bool, Arity> &use_broadcast,
    const uint32_t numel,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    int num,
    int block_offset,
    int read_lens,
    Functor func) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;
  __simd__ ArgsT args[VecSize];
  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];

#ifdef PADDLE_WITH_XPU_KP
  BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
      ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
#else
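  // Pure-broadcast path: decompose each output index with fast divmod along
  // the simplified output shape (configs[0]), accumulate every input's
  // element offset from its own strides, then gather all inputs at once.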
  if (LoadType == kBroadcast) {
    uint32_t index_bc[Arity][VecSize] = {0};
    Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      if (IsBoundary && idx == numel) break;
#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].rank) break;
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
#pragma unroll
        for (int j = 0; j < Arity; ++j) {
          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
        }
      }
    }
    Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
  } else {
    BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
        ins, args, configs, use_broadcast, block_offset, num, numel, read_lens);
  }
#endif
  SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                     VecSize,
                                     Functor,
                                     ArgsT,
                                     Arity>()(func, args, result, read_lens);
  phi::funcs::
      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
          outs, result, block_offset, num, read_lens);
}

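// Grid-level driver: tiles that end before `main_offset` are full and run
// without boundary checks; the remaining `tail_tid` elements take the
// boundary-checked (IsBoundary = true) path. The XPU variant loops
// grid-stride over tiles; on GPU each block handles exactly one tile.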
template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int LoadType>
__global__ void VectorizedBroadcastKernel(
    phi::Array<const _ptr_ char *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    phi::Array<bool, Arity> use_broadcast,
    uint32_t numel,
    phi::Array<kps::details::BroadcastConfig, Arity> configs,
    int main_offset,
    int tail_tid,
    int read_lens,
    Functor func) {
#ifdef PADDLE_WITH_XPU_KP
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
  for (; block_offset < main_offset; block_offset += stride) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            BLOCK_NUM_X * read_lens,
                                            block_offset,
                                            read_lens,
                                            func);
  }
  int num = numel - block_offset;
  if (num > 0) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  true,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            num,
                                            block_offset,
                                            read_lens,
                                            func);
  }
#else
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  if (block_offset < main_offset) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            BLOCK_NUM_X * VecSize,
                                            block_offset,
                                            read_lens,
                                            func);
  } else {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  true,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            tail_tid,
                                            block_offset,
                                            read_lens,
                                            func);
  }
#endif
}

template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchBroadcastKernel(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    const LoaderTypeClassifier<OutT, Arity, Functor> &loader_classifier) {
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;
  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
  }

#ifdef PADDLE_WITH_XPU_KP
  int numel = (*outs)[0]->numel();
  const int threads = 64;
  const int blocks = 8;
  int read_lens = configs[0].buf_len;
  auto stream = ctx.x_context()->xpu_stream;
  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  int tail_tid = numel % (read_lens * threads);

  VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, false>
      <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                       outs_data,
                                       loader_classifier.use_broadcast,
                                       numel,
                                       configs,
                                       main_offset,
                                       tail_tid,
                                       read_lens,
                                       func);
#else
  const auto &numel = loader_classifier.numel;
  auto gpu_config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
  auto stream = ctx.stream();
  auto threads = gpu_config.GetBlockSize();
  auto blocks = gpu_config.block_per_grid;
  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
  int tail_tid = numel % (VecSize * threads);

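  // Choose the load type from the host-side classification: all-elementwise
  // inputs take plain vectorized loads; when more than half of the inputs
  // broadcast, use the shared divmod-based gather; otherwise mix per-input
  // load schemes.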
  if (loader_classifier.all_elementwise) {
    VectorizedBroadcastKernel<Functor,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              kElementwise>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         VecSize,
                                         func);
  } else if (loader_classifier.broadcast_num > (Arity >> 1)) {
    constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed;
    VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, type_>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         VecSize,
                                         func);
  } else {
    VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, kMixed>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         VecSize,
                                         func);
  }
#endif
}

#ifndef PADDLE_WITH_XPU_KP
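// Maps a linear index in the output to the linear index of a (possibly
// broadcast) input: each output coordinate is peeled off via src_strides and
// re-accumulated with dst_strides; size-1 input dimensions, whose stride
// does not change, contribute nothing.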
HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
    int64_t src_idx,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
    int rank) {
  int64_t dst_idx = 0;
  int64_t old_src_idx = src_idx;
  for (int k = 0; k < rank; ++k) {
    auto local_idx = src_idx / src_strides[k + 1];
    src_idx -= local_idx * src_strides[k + 1];

    if (dst_strides[k] != dst_strides[k + 1]) {
      dst_idx += local_idx * dst_strides[k + 1];
    }
  }
  return dst_idx;
}

template <int N>
struct MaxWithOne {
  static constexpr auto kValue = (N >= 1 ? N : 1);
};

template <int Index, int VecSize>
struct ReadVecDataWithInt64Index {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(
      const Array1 &in,
      ArgsT *args,
      int64_t idx,
      const Array2 &need_broadcast,
      const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
      const Array3 &dst_strides,
      int rank,
      bool is_boundary) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    if (is_boundary) {
#pragma unroll
      for (int i = 0; i < VecSize; ++i) {
        std::get<Index>(args[i]) = in[Index][ConvertSrcIdxToDstIdx(
            idx + i, src_strides, dst_strides[Index], rank)];
      }
    } else {
      if (!need_broadcast[Index]) {
        kps::ReadData<Type, VecSize, 1, ArgsT, Index, false>(
            args, reinterpret_cast<const _ptr_ Type *>(in[Index]) + idx, 1);
      } else {
#pragma unroll
        for (int i = 0; i < VecSize; ++i) {
          std::get<Index>(args[i]) = in[Index][ConvertSrcIdxToDstIdx(
              idx + i, src_strides, dst_strides[Index], rank)];
        }
      }
    }
  }
};

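// Broadcast kernel with 64-bit indexing, for outputs whose element count
// exceeds the int32 range. A grid-stride loop runs the vectorized main body
// up to `limit`, then a scalar tail re-reads with boundary checks.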
template <typename OutT, typename Functor, int VecSize, int NumIns>
__global__ void BroadcastKernelWithInt64Index(
    const phi::Array<const _ptr_ char *__restrict__, MaxWithOne<NumIns>::kValue>
        &ins,
    OutT *out,
    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<NumIns>::kValue> ins_strides,
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides,
    phi::Array<bool, MaxWithOne<NumIns>::kValue> need_broadcasts,
    int rank,
    Functor functor) {
  int64_t numel = out_strides[0];
  int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  int64_t limit = numel - VecSize;

  using Traits = phi::funcs::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;

  ArgsT args[VecSize];
  phi::AlignedVector<OutT, VecSize> out_vec;
  for (; idx <= limit; idx += stride) {
    Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
        ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, false);

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      out_vec[i] = static_cast<OutT>(Apply(functor, args[i]));
    }
    phi::Store<OutT, VecSize>(out_vec, out + idx);
  }

  if (idx < numel) {
    int remain = numel - idx;  // remain is always less than VecSize, therefore
                               // `int` is enough here
    Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
        ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, true);
    for (int i = 0; i < remain; ++i) {
      out[idx + i] = static_cast<OutT>(Apply(functor, args[i]));
    }
  }
}

template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    PADDLE_THROW(phi::errors::PermissionDenied(
        "Unreachable code branch. This may be a bug."));
  }
};

template <typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                 Functor,
                                                 Arity,
                                                 /*NumOuts=*/1,
                                                 VecSize> {
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    phi::Array<const _ptr_ char *__restrict__, MaxWithOne<Arity>::kValue>
        ins_ptrs;
    UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_ptrs);
    auto *out_tensor = (*outs)[0];
    auto *out_ptr = ctx.Alloc<OutT>(out_tensor);

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank>,
               MaxWithOne<Arity>::kValue>
        ins_expand_dims;
    phi::Array<int64_t, phi::DDim::kMaxRank> broadcast_out_dims;
    int rank;
    if (Arity == 1) {
      rank = ins[0]->dims().size();
      for (int i = 0; i < rank; ++i) {
        broadcast_out_dims[i] = ins[0]->dims()[i];
      }
      ins_expand_dims[0] = broadcast_out_dims;
    } else if (Arity >= 2) {
      CalculateBroadcastDims(ins[0]->dims().Get(),
                             ins[1]->dims().Get(),
                             ins[0]->dims().size(),
                             ins[1]->dims().size(),
                             axis,
                             ins_expand_dims[0].GetMutable(),
                             ins_expand_dims[1].GetMutable(),
                             broadcast_out_dims.GetMutable(),
                             &rank);
      for (int i = 2; i < Arity; ++i) {
        auto tmp_dims = broadcast_out_dims;
        phi::Array<int64_t, phi::DDim::kMaxRank> tmp_expand_dims;
        int tmp_rank;
        PADDLE_ENFORCE_GE(rank,
                          ins[i]->dims().size(),
                          phi::errors::InvalidArgument(
                              "Unsupported reverse broadcast when the input "
                              "tensor number is larger than 2."));
        CalculateBroadcastDims(tmp_dims.Get(),
                               ins[i]->dims().Get(),
                               rank,
                               ins[i]->dims().size(),
                               axis,
                               tmp_expand_dims.GetMutable(),
                               ins_expand_dims[i].GetMutable(),
                               broadcast_out_dims.GetMutable(),
                               &tmp_rank);
        PADDLE_ENFORCE_EQ(rank,
                          tmp_rank,
                          phi::errors::InvalidArgument(
                              "Wrong broadcast algorithm. This may be a bug."));
      }
    }

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<Arity>::kValue>
        ins_strides;
    phi::Array<bool, MaxWithOne<Arity>::kValue> need_broadcasts;
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides;
    const auto &out_dims = out_tensor->dims();
    if (rank <= out_dims.size()) {
      out_strides = ShapeToStride(out_dims.Get(), rank);
    } else {
      out_strides = ShapeToStride(broadcast_out_dims.Get(), rank);
    }

    for (int i = 0; i < Arity; ++i) {
      ins_strides[i] = ShapeToStride(ins_expand_dims[i].Get(), rank);
      need_broadcasts[i] =
          !IsSameShape(out_strides.Get(), ins_strides[i].Get(), rank + 1);
    }

    int64_t numel = out_strides[0];
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);

    BroadcastKernelWithInt64Index<OutT, Functor, VecSize, Arity>
        <<<gpu_config.block_per_grid,
           gpu_config.thread_per_block,
           0,
           ctx.stream()>>>(ins_ptrs,
                           out_ptr,
                           ins_strides,
                           out_strides,
                           need_broadcasts,
                           rank,
                           functor);
  }

 private:
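  // Aligns two shapes of possibly different rank at `axis`, producing the
  // padded per-input dims and the broadcast output dims (size-1 axes expand
  // to the other input's extent); when ny > nx it recurses with the
  // arguments swapped.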
  static void CalculateBroadcastDims(const int64_t *x_dims,
                                     const int64_t *y_dims,
                                     int nx,
                                     int ny,
                                     int axis,
                                     int64_t *x_out_dims,
                                     int64_t *y_out_dims,
                                     int64_t *broadcast_out_dims,
                                     int *length) {
    PADDLE_ENFORCE_GE(
        axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis));
    if (nx == ny) {
      *length = nx;
      for (int i = 0; i < nx; ++i) {
        if (x_dims[i] != y_dims[i]) {
          PADDLE_ENFORCE_EQ(
              x_dims[i] == 1 || y_dims[i] == 1,
              true,
              phi::errors::InvalidArgument("Cannot broadcast input shape where "
                                           "x_dims[%d] = %d, y_dims[%d] = %d.",
                                           i,
                                           x_dims[i],
                                           i,
                                           y_dims[i]));
        }
        broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i]);
        x_out_dims[i] = x_dims[i];
        y_out_dims[i] = y_dims[i];
      }
    } else if (nx > ny) {
      *length = nx;
      for (int i = nx - axis; i < ny; ++i) {
        PADDLE_ENFORCE_EQ(
            y_dims[i],
            1,
            phi::errors::InvalidArgument(
                "The trailing Y.shape[%d] should be 1 but got %d.",
                i,
                y_dims[i]));
      }

      for (int i = 0; i < nx; ++i) {
        if (i >= axis && i - axis < ny) {
          if (x_dims[i] != y_dims[i - axis]) {
            PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1,
                              true,
                              phi::errors::InvalidArgument(
                                  "Cannot broadcast input shape where "
                                  "x_dims[%d] = %d, y_dims[%d] = %d.",
                                  i,
                                  x_dims[i],
                                  i - axis,
                                  y_dims[i - axis]));
          }
          broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i - axis]);
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = y_dims[i - axis];
        } else {
          broadcast_out_dims[i] = x_dims[i];
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = 1;
        }
      }
    } else {
      CalculateBroadcastDims(y_dims,
                             x_dims,
                             ny,
                             nx,
                             axis,
                             y_out_dims,
                             x_out_dims,
                             broadcast_out_dims,
                             length);
    }
  }

  static bool IsSameShape(const int64_t *x, const int64_t *y, int rank) {
    for (int i = 0; i < rank; ++i) {
      if (x[i] != y[i]) return false;
    }
    return true;
  }

  static phi::Array<int64_t, phi::DDim::kMaxRank + 1> ShapeToStride(
      const int64_t *arr, int rank) {
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> strides;
    strides[rank] = 1;
    for (int i = rank - 1; i >= 0; --i) {
      strides[i] = strides[i + 1] * arr[i];
    }
    return strides;
  }
};
#endif

template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
#ifndef PADDLE_WITH_XPU_KP
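  // The 64-bit index path is only implemented for a single output and at
  // most three inputs; it is taken when the element count overflows int32.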
  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
  bool use_int64_index_kernel =
      kEnabledInt64IndexKernel &&
      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
  if (use_int64_index_kernel) {
    auto loader_classifier =
        LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
    switch (loader_classifier.vec_size) {
      case VecSizeL: {
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeL>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeM: {
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeM>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeS: {
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeS>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      default: {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Unsupported vectorized size: %d!", loader_classifier.vec_size));
        break;
      }
    }
    return;
  }
#endif

  phi::Array<kps::details::BroadcastConfig, kArity> configs;
#ifdef PADDLE_WITH_XPU_KP
  PADDLE_ENFORCE_EQ(
      ins.size(),
      2,
      phi::errors::InvalidArgument(
          "XPU only support inputs is 2, but received %d", ins.size()));

  auto loader_classifier = LoaderTypeClassifier<OutT, kArity, Functor>();
  const auto dims_simplifier =
      BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
  if (VLOG_IS_ON(6)) {
    DimsSimplifiedLogger<int64_t>::Log(
        ins, outs, dims_simplifier, "XPU Broadcast");
  }
  configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                             dims_simplifier.in_dims[0],
                                             dims_simplifier.in_dims[1],
                                             dims_simplifier.rank);
  configs[1] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                             dims_simplifier.in_dims[1],
                                             dims_simplifier.in_dims[0],
                                             dims_simplifier.rank);
  auto type = kps::details::OptType::CanNotOptimize;
  bool is_optimize = configs[0].cmp_type != type;
  int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
  auto loader_classifier =
      LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
  if (!loader_classifier.all_elementwise) {
    const auto dims_simplifier =
        BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);

    if (VLOG_IS_ON(6)) {
      DimsSimplifiedLogger<int64_t>::Log(
          ins, outs, dims_simplifier, "GPU Broadcast");
    }
    for (int i = 0; i < kArity; ++i) {
      // If the data shape is [m, n], data_dim should be set to {n, m};
      // e.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.
      if (ins[i]->numel()) {
        configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                                   dims_simplifier.in_dims[i],
                                                   dims_simplifier.rank);
      }
    }
  }
#endif
  switch (loader_classifier.vec_size) {
    case VecSizeL: {
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeL>(
          ctx, ins, outs, func, configs, loader_classifier);
      break;
    }
    case VecSizeM: {
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeM>(
          ctx, ins, outs, func, configs, loader_classifier);
      break;
    }
    case VecSizeS: {
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeS>(
          ctx, ins, outs, func, configs, loader_classifier);
      break;
    }
    default: {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported vectorized size: %d!", loader_classifier.vec_size));
      break;
    }
  }
}

template <typename OutT, typename Functor, int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
                     const std::vector<const DenseTensor *> &ins,
                     std::vector<DenseTensor *> *outs,
                     Functor func,
                     int axis = -1) {
  // When there are multiple inputs, the output's rank should equal the
  // maximum rank of all inputs.
  using Traits = phi::funcs::FunctionTraits<Functor>;
  const int kArity = Traits::arity;
  PADDLE_ENFORCE_EQ(
      ins.size(),
      kArity,
      phi::errors::InvalidArgument("The number of inputs is expected to "
                                   "equal the arity of the functor, but "
                                   "received %d inputs while the functor "
                                   "arity is %d.",
                                   ins.size(),
                                   kArity));
  PADDLE_ENFORCE_EQ(
      outs->size(),
      NumOuts,
      phi::errors::InvalidArgument("The number of outputs shall equal "
                                   "NumOuts, but the number of outputs is "
                                   "%d while NumOuts is %d.",
                                   outs->size(),
                                   NumOuts));

  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (auto *in : ins) {
    max_rank = std::max(max_rank, in->dims().size());
    min_rank = std::min(min_rank, in->dims().size());
  }
  if (ins.size() == 1) {
    // When there is only 1 input, the input's rank may be less than the
    // output's rank.
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  axis = axis == -1 ? max_rank - min_rank : axis;
  BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
      ctx, ins, outs, axis, func);
}
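
// A minimal usage sketch (assuming a binary functor such as
// funcs::AddFunctor<T>, with DenseTensors x, y and z on the same device):
//
//   std::vector<const DenseTensor *> ins = {&x, &y};
//   std::vector<DenseTensor *> outs = {&z};
//   BroadcastKernel<T, funcs::AddFunctor<T>, 1>(
//       dev_ctx, ins, &outs, funcs::AddFunctor<T>());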

template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
                        const DenseTensor &x,
                        const DenseTensor &y,
                        Functor func,
                        DenseTensor *z,
                        int axis = -1) {
  std::vector<const DenseTensor *> ins = {&x, &y};
  std::vector<DenseTensor *> outs = {z};
  dev_ctx.template Alloc<OutType>(z);

  BroadcastKernel<OutType, Functor, 1>(dev_ctx, ins, &outs, func, axis);
}

template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  dev_ctx.template Alloc<T>(z);
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
}

#else

template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  dev_ctx.template Alloc<T>(z);
  if (x_dims.size() >= y_dims.size()) {
    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
  } else {
    funcs::ElementwiseCompute<InverseFunctor, T>(
        dev_ctx, x, y, InverseFunctor(), z, axis);
  }
}
#endif

}  // namespace funcs
}  // namespace phi