broadcast_function.h 39.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <sstream>
18
#include "paddle/phi/kernels/funcs/elementwise_base.h"
19

20
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
21
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
22

23
namespace kps = phi::kps;
24 25 26

#endif

27
namespace phi {
28 29
namespace funcs {

30 31
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)

32 33
// Selects which BroadcastDataLoader variant a kernel instantiation uses.
enum BroadcastLoadType {
  // Generic loader: each input is read through either the broadcast or the
  // contiguous path, decided per input at runtime.
  kMixed = 1,
  // Every input is addressed by decomposing the output index (divmod path).
  kBroadcast = 2,
  // No input needs broadcasting; inputs are read contiguously.
  kElementwise = 3
};

34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
// Per-input step used through UnrollerWithoutVecSize: records the data pointer
// of input `Index` and whether it needs broadcasting. An input needs
// broadcasting when its element count differs from the output's `numel`.
template <int Index>
struct UseBroadcast {
  template <typename ArgsT, typename Array1, typename Array2>
  static HOSTDEVICE void Apply(
      const std::vector<const DenseTensor *> &ins_tensor,
      const ArgsT &args,
      int64_t numel,
      Array1 *ins_data,
      Array2 *use_broadcast,
      int *broadcast_num,
      bool *all_elementwise) {
    const DenseTensor *tensor = ins_tensor[Index];
    (*ins_data)[Index] = (const _ptr_ char *)(tensor->data());

    const bool needs_bc = tensor->numel() != numel;
    (*use_broadcast)[Index] = needs_bc;
    if (needs_bc) {
      ++(*broadcast_num);
    }
    // Stays true only while every visited input is purely elementwise.
    *all_elementwise &= !needs_bc;
  }
};

// Inspects the inputs/outputs of a broadcast launch on the host and gathers
// what the launcher needs:
//  - vec_size: the vector width. It starts at 4, is updated by VecSizeGetter
//    (defined elsewhere) for the inputs, and is then capped by the alignment
//    of all output pointers.
//  - use_broadcast / broadcast_num / all_elementwise: which inputs have an
//    element count different from the output's.
//  - ins_data: type-erased input data pointers.
// All outputs must have identical dims (enforced).
template <typename OutT, int Arity, typename Functor>
struct LoaderTypeClassifier {
 public:
  int64_t numel{0};  // element count of (*outs)[0]
  int vec_size{4};
  int broadcast_num{0};  // number of inputs that need broadcasting
  bool all_elementwise{true};  // true iff no input needs broadcasting

  phi::Array<bool, Arity> use_broadcast;
  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;

  LoaderTypeClassifier() {}
  LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
                       std::vector<DenseTensor *> *outs) {
    using Traits = phi::funcs::FunctionTraits<Functor>;
    using ArgsT = typename Traits::ArgsTuple;
    ArgsT arg;
    // OR of all output addresses: a single alignment query on the combined
    // value is only as aligned as the least-aligned output.
    uint64_t out_addr = reinterpret_cast<uint64_t>((*outs)[0]->data<OutT>());

    UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);

    // Use an int bound so the comparison is not signed/unsigned mixed and `i`
    // still matches the `%d` in the error message.
    const int num_outs = static_cast<int>(outs->size());
    for (int i = 1; i < num_outs; ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          phi::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
      out_addr =
          (out_addr | reinterpret_cast<uint64_t>((*outs)[i]->data<OutT>()));
    }

    vec_size = std::min(
        vec_size,
        phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr)));
    numel = (*outs)[0]->numel();
    UnrollerWithoutVecSize<UseBroadcast, Arity>::step(ins,
                                                      arg,
                                                      numel,
                                                      &ins_data,
                                                      &use_broadcast,
                                                      &broadcast_num,
                                                      &all_elementwise);
  }
};
102

103
// Common broadcast/elementwise Loader.
// Generic loader, used for every LoadType that has no specialization below
// (kMixed on CUDA/HIP; all load types on XPU-KP, where the specializations
// are compiled out). Whether input `Index` is read through the broadcast
// path (kps::ReadDataBc with configs[Index]) or the contiguous path
// (kps::ReadData starting at block_offset) is decided at runtime from
// use_broadcast[Index].
template <int Index, int VecSize, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    // Initialise this input's slots to 1 before loading (presumably so lanes
    // skipped by boundary handling still hold a defined value).
    kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));

    if (use_broadcast[Index]) {
      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
          block_offset,
          configs[Index],
          numel,
          VecSize);
    }
    // NOTE: If use if...else... with condition `use_broadcast[Index]` here,
    // there will be some errs with clang12 while compiling in ROCm.
    // When the compiler is upgraded, if...else... may be used.
    if (!use_broadcast[Index]) {
      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
          args,
          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
          num,
          VecSize);
    }
  }
};

139 140
/* BroadcastDataLoaders Partial specialization */
#ifndef PADDLE_WITH_XPU_KP
141
// Scalar elementwise Loader with consideration of IsBoundary.
// Used when no input is broadcast and the block may run past the end of the
// data. Each thread loads its VecSize consecutive elements one at a time.
// Slots whose global index is >= numel keep the value 1.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    // First element handled by this thread. Kept unsigned so the bound check
    // against `numel` is not a signed/unsigned comparison.
    const uint32_t thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = static_cast<Type>(1);
      const uint32_t index = thread_offset + idx;
      if (index < numel) {
        std::get<Index>(args[idx]) =
            reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
      }
    }
  }
};

// Vectorized elementwise Loader without consideration of IsBoundary.
// Used for full blocks when no input is broadcast. Each thread performs one
// VecType load covering its VecSize consecutive elements. This assumes
// ins[Index] is suitably aligned for VecType (presumably guaranteed by the
// vec_size selection) — TODO confirm.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    using VecType = phi::kps::details::VectorType<Type, VecSize>;

    // Index in units of VecType, not elements. With the CUDA/HIP launch in
    // VectorizedBroadcastKernel (block_offset = blockIdx.x * blockDim.x *
    // VecSize), this equals block_offset / VecSize + threadIdx.x.
    int vec_idx = threadIdx.x + blockIdx.x * blockDim.x;
    const VecType *__restrict__ vec_input =
        reinterpret_cast<const VecType *__restrict__>(ins[Index]);
    VecType vec_temp = vec_input[vec_idx];

#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      std::get<Index>(args[idx]) = vec_temp.val[idx];
    }
  }
};

// Common broadcast data loader.
// For each of this thread's VecSize output elements, the linear output index
// is decomposed with the output's fast divmoders (configs[0]). The input
// offset is then accumulated as sum(remainder_i * configs[Index].strides[i]).
// Lanes skipped at the boundary keep index 0.
template <int Index, int VecSize, bool IsBoundary>
struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    uint32_t index_bc[VecSize];
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      index_bc[k] = 0;
      std::get<Index>(args[k]) = static_cast<Type>(1);
    }

    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      // `>=` instead of `==`: in the tail block, a thread whose first element
      // already lies past numel never hits idx == numel exactly. It would
      // otherwise decompose indices for elements that are out of range.
      if (IsBoundary && idx >= numel) {
        break;
      }
#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].rank) break;
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
        index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
      }
    }

#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      std::get<Index>(args[k]) =
          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
    }
  }
};
234

235
#endif
236

237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
// static broadcast unroller
// Compile-time loop over the input indices Begin, Begin + 1, ..., End - 1.
// For each index I it calls Func<I, VecSize, IsBoundary, LoadType>::Apply(args...).
// VectorizedBroadcastKernelImpl uses it with BroadcastDataLoader to load
// every functor input.
template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
          typename Func,
          bool IsBoundary,
          int LoadType,
          int VecSize,
          int End,
          int Begin = 0>
struct BcUnroller {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&...args) {
    Func<Begin, VecSize, IsBoundary, LoadType>::Apply(
        std::forward<Args>(args)...);
    // NOTE(review): args are forwarded above and reused here. This is safe at
    // the call site in this file because only named lvalues are passed, so
    // std::forward yields lvalues and nothing is moved from.
    BcUnroller<Func, IsBoundary, LoadType, VecSize, End, Begin + 1>::step(
        args...);
  }
};

// Recursion terminator: Begin == End, no inputs left to process.
template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
          typename Func,
          bool IsBoundary,
          int LoadType,
          int VecSize,
          int End>
struct BcUnroller<Func, IsBoundary, LoadType, VecSize, End, End> {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&...args) {}
};

// Per-chunk body of VectorizedBroadcastKernel:
//  1. loads every input for this thread's VecSize elements with the
//     BroadcastDataLoader selected by (IsBoundary, LoadType);
//  2. applies `func` to each element;
//  3. writes the results to the output(s) via ElementwiseWriteDataCallerBc.
// `block_offset` is the first element of the chunk. `num` is the number of
// elements the chunk covers: the full chunk size on the unchecked path, or the
// remaining count on the boundary path.
template <typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsBoundary,
          int LoadType>
__device__ void VectorizedBroadcastKernelImpl(
    const phi::Array<const _ptr_ char *__restrict__, Arity> &ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    const phi::Array<bool, Arity> &use_broadcast,
    const uint32_t numel,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    int num,
    int block_offset,
    int read_lens,
    Functor func) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;
  __simd__ ArgsT args[VecSize];
  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];

  BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
      ins, args, configs, use_broadcast, block_offset, num, numel);

  SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                     VecSize,
                                     Functor,
                                     ArgsT,
                                     Arity>()(func, args, result, read_lens);

  phi::funcs::
      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
          outs, result, block_offset, num, read_lens);
}

301
// Kernel entry for the vectorized broadcast/elementwise path.
//  - XPU-KP: each block walks its chunks of BLOCK_NUM_X * read_lens elements
//    with stride BLOCK_NUM_X * GRID_NUM_X * read_lens while below main_offset.
//    It then handles the remaining numel - block_offset elements (if any)
//    with boundary checks.
//  - CUDA/HIP: each block handles one chunk of BLOCK_NUM_X * VecSize elements.
//    Blocks starting before main_offset take the unchecked path; the others
//    take the boundary path with `tail_tid` elements.
template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int LoadType>
__global__ void VectorizedBroadcastKernel(
    phi::Array<const _ptr_ char *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    phi::Array<bool, Arity> use_broadcast,
    uint32_t numel,
    phi::Array<kps::details::BroadcastConfig, Arity> configs,
    int main_offset,
    int tail_tid,
    int read_lens,
    Functor func) {
#ifdef PADDLE_WITH_XPU_KP
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
  for (; block_offset < main_offset; block_offset += stride) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            BLOCK_NUM_X * read_lens,
                                            block_offset,
                                            read_lens,
                                            func);
  }
  int num = numel - block_offset;
  if (num > 0) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  true,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            num,
                                            block_offset,
                                            read_lens,
                                            func);
  }
#else
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  if (block_offset < main_offset) {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            BLOCK_NUM_X * VecSize,
                                            block_offset,
                                            read_lens,
                                            func);
  } else {
    VectorizedBroadcastKernelImpl<OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  true,
                                  LoadType>(ins,
                                            outs,
                                            use_broadcast,
                                            numel,
                                            configs,
                                            tail_tid,
                                            block_offset,
                                            read_lens,
                                            func);
  }
#endif
}

393
// Allocates the outputs and launches VectorizedBroadcastKernel.
//  - XPU-KP: fixed launch of 8 blocks x 64 threads. Each thread processes
//    read_lens (= configs[0].buf_len) elements per step. Only the generic
//    BroadcastDataLoader is compiled there, so the LoadType argument has no
//    effect.
//  - CUDA/HIP: launch shape from GetGpuLaunchConfig1D. The load strategy is
//    chosen from `loader_classifier`:
//      kElementwise when no input needs broadcasting;
//      kBroadcast when more than half of the inputs need broadcasting
//        (kMixed when Arity == 1);
//      kMixed otherwise.
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchBroadcastKernel(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    const LoaderTypeClassifier<OutT, Arity, Functor> &loader_classifier) {
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;
  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
  }

#ifdef PADDLE_WITH_XPU_KP
  int numel = (*outs)[0]->numel();
  const int threads = 64;
  const int blocks = 8;
  int read_lens = configs[0].buf_len;
  auto stream = ctx.x_context()->xpu_stream;
  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  int tail_tid = numel % (read_lens * threads);

  VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, false>
      <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                       outs_data,
                                       loader_classifier.use_broadcast,
                                       numel,
                                       configs,
                                       main_offset,
                                       tail_tid,
                                       read_lens,
                                       func);
#else
  const auto &numel = loader_classifier.numel;
  auto gpu_config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
  auto stream = ctx.stream();
  auto threads = gpu_config.GetBlockSize();
  auto blocks = gpu_config.block_per_grid;
  // Elements covered by full (unchecked) blocks, and the leftover count that
  // the boundary-checked block(s) handle.
  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
  int tail_tid = numel % (VecSize * threads);

  if (loader_classifier.all_elementwise) {
    VectorizedBroadcastKernel<Functor,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              kElementwise>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         VecSize,
                                         func);
  } else if (loader_classifier.broadcast_num > (Arity >> 1)) {
    constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed;
    VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, type_>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         VecSize,
                                         func);
  } else {
    VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, kMixed>
        <<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
                                         outs_data,
                                         loader_classifier.use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         VecSize,
                                         func);
  }
#endif
}

478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
#ifndef PADDLE_WITH_XPU_KP
// Maps a linear index into the broadcast output (src) to the linear index of
// the corresponding element of one input (dst).
// Both stride arrays have rank + 1 meaningful entries in the layout produced by
// ShapeToStride: entry k is the product of dims [k, rank), and entry rank is 1.
// When dst_strides[k] == dst_strides[k + 1], input dimension k has extent 1
// (a broadcast dimension) and contributes nothing to the input offset.
HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
    int64_t src_idx,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
    int rank) {
  int64_t dst_idx = 0;
  for (int k = 0; k < rank; ++k) {
    // Coordinate of the element along output dimension k.
    auto local_idx = src_idx / src_strides[k + 1];
    src_idx -= local_idx * src_strides[k + 1];

    if (dst_strides[k] != dst_strides[k + 1]) {
      dst_idx += local_idx * dst_strides[k + 1];
    }
  }
  return dst_idx;
}

497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515
// Compile-time max(N, 1). Keeps phi::Array sizes non-zero when a kernel is
// instantiated with zero inputs.
template <int N>
struct MaxWithOne {
  static constexpr int kValue = (N < 1) ? 1 : N;
};

// Loads input `Index` for the VecSize consecutive output elements that start
// at linear output index `idx`, using 64-bit index math.
//  - in[Index]: type-erased pointer to the input data (reinterpreted as Type).
//  - src_strides: strides of the broadcast output. BroadcastKernelWithInt64Index
//    passes out_strides, whose entry 0 is the output numel.
//  - dst_strides[Index]: strides of input `Index` expanded to the output rank.
//  - need_broadcast[Index]: false means the input has the output's shape and
//    can be read contiguously.
//  - is_boundary: set for the final partial vector. Only lanes whose output
//    index is < src_strides[0] are loaded; the other lanes of args are left
//    untouched.
template <int Index, int VecSize>
struct ReadVecDataWithInt64Index {
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(
      const Array1 &in,
      ArgsT *args,
      int64_t idx,
      const Array2 &need_broadcast,
      const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
      const Array3 &dst_strides,
      int rank,
      bool is_boundary) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    // The pointers are stored as bytes. Index them as Type so the element
    // offsets computed below address whole elements, not single bytes.
    const _ptr_ Type *in_data = reinterpret_cast<const _ptr_ Type *>(in[Index]);
    if (is_boundary) {
      // Total number of output elements; lanes at or past it must not be read.
      const int64_t numel = src_strides[0];
#pragma unroll
      for (int i = 0; i < VecSize; ++i) {
        if (idx + i < numel) {
          std::get<Index>(args[i]) = in_data[ConvertSrcIdxToDstIdx(
              idx + i, src_strides, dst_strides[Index], rank)];
        }
      }
    } else {
      if (!need_broadcast[Index]) {
        kps::ReadData<Type, VecSize, 1, ArgsT, Index, false>(
            args, in_data + idx, 1);
      } else {
#pragma unroll
        for (int i = 0; i < VecSize; ++i) {
          std::get<Index>(args[i]) = in_data[ConvertSrcIdxToDstIdx(
              idx + i, src_strides, dst_strides[Index], rank)];
        }
      }
    }
  }
};

536
// Single-output broadcast kernel that does all index arithmetic in int64_t.
// The caller selects it for very large outputs.
// The main grid-stride loop handles full VecSize vectors and stores them with
// phi::Store. The thread whose start index falls in the last partial vector
// then handles the remaining numel % VecSize elements one by one.
template <typename OutT, typename Functor, int VecSize, int NumIns>
__global__ void BroadcastKernelWithInt64Index(
    // Taken by value: a reference parameter of a __global__ function would
    // bind to the launching thread's host-side array, which the device cannot
    // dereference.
    phi::Array<const _ptr_ char *__restrict__, MaxWithOne<NumIns>::kValue> ins,
    OutT *out,
    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<NumIns>::kValue> ins_strides,
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides,
    phi::Array<bool, MaxWithOne<NumIns>::kValue> need_broadcasts,
    int rank,
    Functor functor) {
  int64_t numel = out_strides[0];
  int64_t idx =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  int64_t limit = numel - VecSize;

  using Traits = phi::funcs::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;

  ArgsT args[VecSize];
  phi::AlignedVector<OutT, VecSize> out_vec;
  for (; idx <= limit; idx += stride) {
    Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
        ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, false);

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      out_vec[i] = static_cast<OutT>(Apply(functor, args[i]));
    }
    phi::Store<OutT, VecSize>(out_vec, out + idx);
  }

  if (idx < numel) {
    int remain = numel - idx;  // remain is always less than VecSize, therefore
                               // `int` is enough here
    Unroller<ReadVecDataWithInt64Index, VecSize, NumIns>::step(
        ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, true);
    for (int i = 0; i < remain; ++i) {
      // Write straight to global memory: `out_vec` only holds VecSize values,
      // so indexing it with the global position would go out of range and
      // the tail results would never reach `out`.
      out[idx + i] = static_cast<OutT>(Apply(functor, args[i]));
    }
  }
}

580
// Primary template: the int64-index path supports only a single output (see
// the NumOuts == 1 specialization). The caller enables it only when
// NumOuts == 1, so reaching this is a bug.
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    PADDLE_THROW(phi::errors::PermissionDenied(
        "Unreachable code branch. This may be a bug."));
  }
};

592 593
// Single-output int64-index launcher.
// 1. Computes the broadcast output shape and the per-input shapes expanded to
//    the same rank. Lower-rank inputs are aligned at `axis`.
// 2. Converts those shapes to strides (ShapeToStride).
// 3. Flags inputs whose strides differ from the output's as needing
//    broadcasting.
// 4. Launches BroadcastKernelWithInt64Index on ctx.stream().
template <typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
                                                 Functor,
                                                 Arity,
                                                 /*NumOuts=*/1,
                                                 VecSize> {
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    phi::Array<const _ptr_ char *__restrict__, MaxWithOne<Arity>::kValue>
        ins_ptrs;
    UnrollerWithoutVecSize<InputSetter, Arity>::step(ins, &ins_ptrs);
    auto *out_tensor = (*outs)[0];
    auto *out_ptr = ctx.Alloc<OutT>(out_tensor);

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank>,
               MaxWithOne<Arity>::kValue>
        ins_expand_dims;
    phi::Array<int64_t, phi::DDim::kMaxRank> broadcast_out_dims;
    int rank;
    if (Arity == 1) {
      rank = ins[0]->dims().size();
      for (int i = 0; i < rank; ++i) {
        broadcast_out_dims[i] = ins[0]->dims()[i];
      }
      ins_expand_dims[0] = broadcast_out_dims;
    } else if (Arity >= 2) {
      CalculateBroadcastDims(ins[0]->dims().Get(),
                             ins[1]->dims().Get(),
                             ins[0]->dims().size(),
                             ins[1]->dims().size(),
                             axis,
                             ins_expand_dims[0].GetMutable(),
                             ins_expand_dims[1].GetMutable(),
                             broadcast_out_dims.GetMutable(),
                             &rank);
      // Fold every further input into the running broadcast shape. Such an
      // input may not have a higher rank than the shape built so far.
      for (int i = 2; i < Arity; ++i) {
        auto tmp_dims = broadcast_out_dims;
        phi::Array<int64_t, phi::DDim::kMaxRank> tmp_expand_dims;
        int tmp_rank;
        PADDLE_ENFORCE_GE(rank,
                          ins[i]->dims().size(),
                          phi::errors::InvalidArgument(
                              "Unsupported reverse broadcast when the input "
                              "tensor number is larger than 2."));
        CalculateBroadcastDims(tmp_dims.Get(),
                               ins[i]->dims().Get(),
                               rank,
                               ins[i]->dims().size(),
                               axis,
                               tmp_expand_dims.GetMutable(),
                               ins_expand_dims[i].GetMutable(),
                               broadcast_out_dims.GetMutable(),
                               &tmp_rank);
        PADDLE_ENFORCE_EQ(rank,
                          tmp_rank,
                          phi::errors::InvalidArgument(
                              "Wrong broadcast algorithm. This may be a bug."));
      }
    }

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<Arity>::kValue>
        ins_strides;
    phi::Array<bool, MaxWithOne<Arity>::kValue> need_broadcasts;
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides;
    const auto &out_dims = out_tensor->dims();
    if (rank <= out_dims.size()) {
      out_strides = ShapeToStride(out_dims.Get(), rank);
    } else {
      out_strides = ShapeToStride(broadcast_out_dims.Get(), rank);
    }

    for (int i = 0; i < Arity; ++i) {
      ins_strides[i] = ShapeToStride(ins_expand_dims[i].Get(), rank);
      need_broadcasts[i] =
          !IsSameShape(out_strides.Get(), ins_strides[i].Get(), rank + 1);
    }

    // out_strides[0] is the product of all output dims, i.e. the element count.
    int64_t numel = out_strides[0];
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);

    BroadcastKernelWithInt64Index<OutT, Functor, VecSize, Arity>
        <<<gpu_config.block_per_grid,
           gpu_config.thread_per_block,
           0,
           ctx.stream()>>>(ins_ptrs,
                           out_ptr,
                           ins_strides,
                           out_strides,
                           need_broadcasts,
                           rank,
                           functor);
  }

 private:
  // Expands x (rank nx) and y (rank ny) to a common rank.
  // The lower-rank tensor is placed starting at dimension `axis` of the
  // higher-rank one; its missing dimensions become 1. The function checks
  // that every aligned pair is equal or contains a 1, and writes the expanded
  // shapes, the broadcast output shape, and the common rank (*length). When
  // nx < ny, the roles of x and y are swapped.
  static void CalculateBroadcastDims(const int64_t *x_dims,
                                     const int64_t *y_dims,
                                     int nx,
                                     int ny,
                                     int axis,
                                     int64_t *x_out_dims,
                                     int64_t *y_out_dims,
                                     int64_t *broadcast_out_dims,
                                     int *length) {
    PADDLE_ENFORCE_GE(
        axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis));
    if (nx == ny) {
      *length = nx;
      for (int i = 0; i < nx; ++i) {
        if (x_dims[i] != y_dims[i]) {
          PADDLE_ENFORCE_EQ(
              x_dims[i] == 1 || y_dims[i] == 1,
              true,
              phi::errors::InvalidArgument("Cannot broadcast input shape where "
                                           "x_dims[%d] = %d, y_dims[%d] = %d.",
                                           i,
                                           x_dims[i],
                                           i,
                                           y_dims[i]));
        }
        broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i]);
        x_out_dims[i] = x_dims[i];
        y_out_dims[i] = y_dims[i];
      }
    } else if (nx > ny) {
      *length = nx;
      // Dimensions of y that would extend past the end of x must be 1.
      for (int i = nx - axis; i < ny; ++i) {
        PADDLE_ENFORCE_EQ(
            y_dims[i],
            1,
            phi::errors::InvalidArgument(
                "The trailing Y.shape[%d] should be 1 but got %d.",
                i,
                y_dims[i]));
      }

      for (int i = 0; i < nx; ++i) {
        if (i >= axis && i - axis < ny) {
          if (x_dims[i] != y_dims[i - axis]) {
            PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1,
                              true,
                              phi::errors::InvalidArgument(
                                  "Cannot broadcast input shape where "
                                  "x_dims[%d] = %d, y_dims[%d] = %d.",
                                  i,
                                  x_dims[i],
                                  i - axis,
                                  y_dims[i - axis]));
          }
          broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i - axis]);
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = y_dims[i - axis];
        } else {
          broadcast_out_dims[i] = x_dims[i];
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = 1;
        }
      }
    } else {
      CalculateBroadcastDims(y_dims,
                             x_dims,
                             ny,
                             nx,
                             axis,
                             y_out_dims,
                             x_out_dims,
                             broadcast_out_dims,
                             length);
    }
  }

  // Element-wise equality of the first `rank` entries of x and y.
  static bool IsSameShape(const int64_t *x, const int64_t *y, int rank) {
    for (int i = 0; i < rank; ++i) {
      if (x[i] != y[i]) return false;
    }
    return true;
  }

  // Suffix products of arr[0, rank): strides[rank] = 1 and
  // strides[i] = arr[i] * strides[i + 1], so strides[0] is the element count.
  // Entries beyond `rank` are left unset.
  static phi::Array<int64_t, phi::DDim::kMaxRank + 1> ShapeToStride(
      const int64_t *arr, int rank) {
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> strides;
    strides[rank] = 1;
    for (int i = rank - 1; i >= 0; --i) {
      strides[i] = strides[i + 1] * arr[i];
    }
    return strides;
  }
};
#endif

786 787 788
template <ElementwiseType ET,
          typename OutT,
          typename Functor,
789
          int kArity,
790 791 792 793 794 795 796
          int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
797 798 799 800 801 802
#ifndef PADDLE_WITH_XPU_KP
  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
  bool use_int64_index_kernel =
      kEnabledInt64IndexKernel &&
      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
  if (use_int64_index_kernel) {
803 804
    auto loader_classifier =
        LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
805
    switch (loader_classifier.vec_size) {
806
      case VecSizeL: {
807
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
808 809 810 811 812 813 814 815 816 817 818
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeL>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeM: {
819
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
820 821 822 823 824 825 826 827 828 829 830
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeM>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeS: {
831
        LaunchBroadcastKernelWithInt64IndexHelper<OutT,
832 833 834 835 836 837 838 839 840 841 842 843
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeS>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      default: {
        PADDLE_THROW(phi::errors::Unimplemented(
844
            "Unsupported vectorized size: %d!", loader_classifier.vec_size));
845 846 847 848 849 850 851
        break;
      }
    }
    return;
  }
#endif

852 853 854 855 856 857 858
  phi::Array<kps::details::BroadcastConfig, kArity> configs;
#ifdef PADDLE_WITH_XPU_KP
  PADDLE_ENFORCE_EQ(
      ins.size(),
      2,
      phi::errors::InvalidArgument(
          "XPU only support inputs is 2, but received %d", ins.size()));
859

860
  auto loader_classifier = LoaderTypeClassifier<OutT, kArity, Functor>();
861 862 863 864 865 866
  const auto dims_simplifier =
      BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
  if (VLOG_IS_ON(6)) {
    DimsSimplifiedLogger<int64_t>::Log(
        ins, outs, dims_simplifier, "XPU Broadcast");
  }
867 868 869 870 871 872 873 874
  configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                             dims_simplifier.in_dims[0],
                                             dims_simplifier.in_dims[1],
                                             dims_simplifier.rank);
  configs[1] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                             dims_simplifier.in_dims[1],
                                             dims_simplifier.in_dims[0],
                                             dims_simplifier.rank);
875 876 877 878
  auto type = kps::details::OptType::CanNotOptimize;
  bool is_optimize = configs[0].cmp_type != type;
  int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
879 880
  auto loader_classifier =
      LoaderTypeClassifier<OutT, kArity, Functor>(ins, outs);
881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897
  if (!loader_classifier.all_elementwise) {
    const auto dims_simplifier =
        BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);

    if (VLOG_IS_ON(6)) {
      DimsSimplifiedLogger<int64_t>::Log(
          ins, outs, dims_simplifier, "GPU Broadcast");
    }
    for (int i = 0; i < kArity; ++i) {
      // if data shape is[m, n], then you should set data_dim = {n, m}
      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
      // if (ins[i]->numel() != (*outs)[0]->numel()) {
      if (ins[i]->numel()) {
        configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                                   dims_simplifier.in_dims[i],
                                                   dims_simplifier.rank);
      }
898
    }
899
  }
900
#endif
901
  switch (loader_classifier.vec_size) {
902
    case VecSizeL: {
903
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeL>(
904
          ctx, ins, outs, func, configs, loader_classifier);
905 906
      break;
    }
907
    case VecSizeM: {
908
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeM>(
909
          ctx, ins, outs, func, configs, loader_classifier);
910 911
      break;
    }
912
    case VecSizeS: {
913
      LaunchBroadcastKernel<OutT, Functor, kArity, NumOuts, VecSizeS>(
914
          ctx, ins, outs, func, configs, loader_classifier);
915 916 917
      break;
    }
    default: {
918
      PADDLE_THROW(phi::errors::Unimplemented(
919
          "Unsupported vectorized size: %d!", loader_classifier.vec_size));
920 921 922 923 924 925 926 927 928 929 930 931 932 933 934
      break;
    }
  }
}

// Entry point for broadcast elementwise launches.
//
// Steps:
//   1. Checks that ins.size() equals the functor's arity and that
//      outs->size() equals NumOuts.
//   2. Resolves axis == -1 to (max input rank - min input rank). With a single
//      input, the output rank also takes part in the max, because that input
//      may have a lower rank than the output.
//   3. Forwards to BroadcastKernelForDifferentVecSize.
// With several inputs, the output's rank is expected to equal the maximum
// input rank.
// InT is accepted for interface compatibility and is not referenced here.
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
                     const std::vector<const DenseTensor *> &ins,
                     std::vector<DenseTensor *> *outs,
                     int axis,
                     Functor func) {
  constexpr int kArity = phi::funcs::FunctionTraits<Functor>::arity;

  PADDLE_ENFORCE_EQ(
      ins.size(),
      kArity,
      phi::errors::InvalidArgument(
          "The number of inputs is expected to be equal to the arity of "
          "functor. But received: the number of inputs is %d, the arity of "
          "functor is %d.",
          ins.size(),
          kArity));
  PADDLE_ENFORCE_EQ(
      outs->size(),
      NumOuts,
      phi::errors::InvalidArgument(
          "Number of outputs shall equal to number of functions, but number "
          "of outputs is %d, of functions is %d.",
          outs->size(),
          NumOuts));

  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (const auto *tensor : ins) {
    const int rank = tensor->dims().size();
    max_rank = std::max(max_rank, rank);
    min_rank = std::min(min_rank, rank);
  }
  if (ins.size() == 1) {
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  if (axis == -1) {
    axis = max_rank - min_rank;
  }

  BroadcastKernelForDifferentVecSize<ET, OutT, Functor, kArity, NumOuts>(
      ctx, ins, outs, axis, func);
}

975 976 977 978 979 980 981 982 983
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
                        const DenseTensor &x,
                        const DenseTensor &y,
                        int axis,
                        Functor func,
                        DenseTensor *z) {
  std::vector<const DenseTensor *> ins = {&x, &y};
  std::vector<DenseTensor *> outs = {z};
984
  dev_ctx.template Alloc<OutType>(z);
985

986 987 988 989
  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
      dev_ctx, ins, &outs, axis, func);
}

990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  dev_ctx.template Alloc<T>(z);
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
}

#else
1006

1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  dev_ctx.template Alloc<T>(z);
  if (x_dims.size() >= y_dims.size()) {
    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
  } else {
    funcs::ElementwiseCompute<InverseFunctor, T>(
        dev_ctx, x, y, axis, InverseFunctor(), z);
  }
}
1026 1027
#endif

1028
}  // namespace funcs
1029
}  // namespace phi