/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/kernels/funcs/elementwise_base.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"

namespace kps = phi::kps;

#endif

namespace phi {
namespace funcs {

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)

// Picks the vector width to use for the broadcast kernel.
//
// ins:  input tensors. Only inputs whose dims equal those of (*outs)[0]
//       restrict the width; inputs with other dims are ignored here.
// outs: output tensors. Every output must have the same dims as (*outs)[0],
//       otherwise InvalidArgument is raised.
//
// Returns the minimum of phi::GetVectorizedSize over all outputs and all
// same-shaped inputs, never more than 4.
template <typename InT, typename OutT>
int GetVecsize(const std::vector<const DenseTensor *> &ins,
               std::vector<DenseTensor *> *outs) {
  int in_vec_size = 4;
  int out_vec_size = 4;

  // Every output, including the first one, limits the usable width. The
  // first output used to be skipped when there was more than one output.
  const int num_outs = static_cast<int>(outs->size());
  for (int i = 0; i < num_outs; ++i) {
    if (i > 0) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          phi::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
    }
    out_vec_size = std::min(
        phi::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()), out_vec_size);
  }

  for (auto *in : ins) {
    auto temp_size = phi::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == (*outs)[0]->dims()
                      ? std::min(temp_size, in_vec_size)
                      : in_vec_size;
  }
  return std::min(out_vec_size, in_vec_size);
}

#ifndef PADDLE_WITH_XPU_KP
// Generic loader: fills args[i] for each of the Arity inputs.
//
// For every input, args[i] is first passed to kps::Init with the value 1.
// Then either kps::ReadDataBc (when use_broadcast[i] is non-zero, using
// configs[i] and numel) or kps::ReadData (reading from ins[i] + block_offset
// with `num`) loads the values.
// NOTE(review): the exact semantics of the kps primitives are defined
// elsewhere; this comment only describes which one is called.
template <typename T,
          int VecSize,
          int Arity,
          bool IsBoundary,
          bool is_all_broadcast>
struct BroadcastDataLoader {
  __device__ __forceinline__ void operator()(
      T args[Arity][VecSize],
      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
      const phi::Array<int, Arity> &use_broadcast,
      const int block_offset,
      const int num,
      const uint32_t numel) {
#pragma unroll
    for (int i = 0; i < Arity; ++i) {
      kps::Init<T, VecSize>(args[i], static_cast<T>(1.0f));
      if (use_broadcast[i]) {
        kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
            args[i], ins[i], block_offset, configs[i], numel, VecSize);
      } else {
        kps::ReadData<T, VecSize, 1, IsBoundary>(
            args[i], ins[i] + block_offset, num, VecSize);
      }
    }
  }
};

// Loader specialization selected when is_all_broadcast is true.
//
// The output index is decomposed once per element using the divmoders of
// configs[0]; each remainder is multiplied by the matching stride of every
// input's config, so one decomposition yields the source offsets for all
// inputs. `use_broadcast` and `num` are not read by this specialization.
template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, true> {
  __device__ __forceinline__ void operator()(
      T args[Arity][VecSize],
      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
      const phi::Array<int, Arity> &use_broadcast,
      const int block_offset,
      const int num,
      const uint32_t numel) {
    uint32_t index_bc[Arity][VecSize];
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        index_bc[j][k] = 0;
        args[j][k] = static_cast<T>(1);
      }
    }

    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      if (IsBoundary) {
        // Use >= rather than ==: a thread whose first element already lies
        // past numel would never hit equality and would keep decomposing
        // out-of-range indices.
        if (idx >= numel) break;
      }

#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].rank) break;
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
#pragma unroll
        for (int j = 0; j < Arity; ++j) {
          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
        }
      }
    }

    // Elements skipped by the boundary break keep index 0, so they read
    // ins[j][0].
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        args[j][k] = ins[j][index_bc[j][k]];
      }
    }
  }
};
#endif

// Processes one tile: load the inputs, apply the functor, write the outputs.
//
// ins / outs:     device pointers of the inputs and outputs.
// use_broadcast:  per-input flag selecting the broadcast read path
//                 (XPU_KP build and the generic loader).
// numel:          forwarded to the broadcast read path.
// configs:        per-input broadcast configuration.
// num:            forwarded to the non-broadcast read and to the writer.
// block_offset:   element offset of this tile.
// read_lens:      forwarded to the KPS primitives, the functor caller and
//                 the writer.
// func:           element-wise functor.
//
// On XPU_KP the loads are issued inline; otherwise BroadcastDataLoader is
// used, specialized on IsAllBroadcast.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsBoundary,
          bool IsAllBroadcast = false>
__device__ void VectorizedBroadcastKernelImpl(
    const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    const phi::Array<int, Arity> &use_broadcast,
    const uint32_t numel,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    int num,
    int block_offset,
    int read_lens,
    Functor func) {
  __simd__ InT args[Arity][VecSize];
  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#ifdef PADDLE_WITH_XPU_KP
#pragma unroll
  for (int i = 0; i < Arity; ++i) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
    if (use_broadcast[i]) {
      kps::ReadDataBc<InT, VecSize, 1, IsBoundary>(
          args[i], ins[i], block_offset, configs[i], numel, read_lens);
    } else {
      kps::ReadData<InT, VecSize, 1, IsBoundary>(
          args[i], ins[i] + block_offset, num, read_lens);
    }
  }
#else
  BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, IsAllBroadcast>()(
      args, ins, configs, use_broadcast, block_offset, num, numel);
#endif

  // Functors that take pointer arguments go through the "any arity" caller.
  constexpr bool kCallElementwiseAny =
      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
  phi::funcs::ElementwisePrimitiveCaller<InT,
                                         ConditionalT<OutT, NumOuts>,
                                         VecSize,
                                         Functor,
                                         Arity,
                                         kCallElementwiseAny>()(
      func, args, result, read_lens);
  phi::funcs::
      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
          outs, result, block_offset, num, read_lens);
}

192 193
// Kernel entry point for vectorized broadcast.
//
// main_offset: element count that can be processed with the non-boundary
//              variant of VectorizedBroadcastKernelImpl.
// tail_tid:    passed as `num` to the boundary variant on the non-XPU path.
// read_lens:   per-thread read length forwarded to the implementation.
//
// XPU_KP build: each block starts at BLOCK_ID_X * BLOCK_NUM_X * read_lens and
// advances by BLOCK_NUM_X * GRID_NUM_X * read_lens while below main_offset;
// whatever remains up to numel is handled by the boundary variant.
// Other builds: each block handles a single tile at
// BLOCK_ID_X * BLOCK_NUM_X * VecSize; tiles below main_offset use the
// non-boundary variant, all others the boundary variant.
template <typename Functor,
          typename InT,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsAllBroadcast>
__global__ void VectorizedBroadcastKernel(
    phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    phi::Array<int, Arity> use_broadcast,
    uint32_t numel,
    phi::Array<kps::details::BroadcastConfig, Arity> configs,
    int main_offset,
    int tail_tid,
    int read_lens,
    Functor func) {
#ifdef PADDLE_WITH_XPU_KP
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
  for (; block_offset < main_offset; block_offset += stride) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  BLOCK_NUM_X * read_lens,
                                                  block_offset,
                                                  read_lens,
                                                  func);
  }
  int num = numel - block_offset;
  if (num > 0) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  true,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  num,
                                                  block_offset,
                                                  read_lens,
                                                  func);
  }
#else
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  if (block_offset < main_offset) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  BLOCK_NUM_X * VecSize,
                                                  block_offset,
                                                  read_lens,
                                                  func);
  } else {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  true,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  tail_tid,
                                                  block_offset,
                                                  read_lens,
                                                  func);
  }
#endif
}

// Host-side launcher for VectorizedBroadcastKernel.
//
// Allocates every output through ctx, collects the raw input pointers and
// marks an input as "broadcast" when its numel differs from the numel of the
// first output. The launch is issued on the context's stream.
//
// XPU_KP build: fixed 8 blocks x 64 threads, read_lens taken from
// configs[0].buf_len, IsAllBroadcast always false.
// Other builds: launch shape from GetGpuLaunchConfig1D, read_lens == VecSize.
// The all-broadcast loader is selected when more than half of the inputs
// (Arity >> 1) need broadcasting and Arity > 1.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
void LaunchBroadcastKernel(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs) {
  int broadcast_num = 0;
  int numel = (*outs)[0]->numel();
  phi::Array<int, Arity> use_broadcast;
  phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;

  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
  }

  for (int i = 0; i < Arity; ++i) {
    if (ins[i]->numel() != numel) {
      broadcast_num++;
      use_broadcast[i] = true;
    } else {
      use_broadcast[i] = false;
    }
    ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
  }

#ifdef PADDLE_WITH_XPU_KP
  const int threads = 64;
  const int blocks = 8;
  int read_lens = configs[0].buf_len;
  auto stream = ctx.x_context()->xpu_stream;
  // main_offset: largest multiple of (read_lens * threads) not above numel.
  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  int tail_tid = numel % (read_lens * threads);

  VectorizedBroadcastKernel<Functor, InT, OutT, Arity, NumOuts, VecSize, false>
      <<<blocks, threads, 0, stream>>>(ins_data,
                                       outs_data,
                                       use_broadcast,
                                       numel,
                                       configs,
                                       main_offset,
                                       tail_tid,
                                       read_lens,
                                       func);
#else
  auto gpu_config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
  int read_lens = VecSize;
  auto stream = ctx.stream();
  auto threads = gpu_config.thread_per_block;
  auto blocks = gpu_config.block_per_grid;
  // main_offset: largest multiple of (read_lens * block size) not above
  // numel; tail_tid is the remainder.
  int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) *
                    read_lens * gpu_config.GetBlockSize();
  int tail_tid = numel % (read_lens * gpu_config.GetBlockSize());

  if (broadcast_num > (Arity >> 1)) {
    VectorizedBroadcastKernel<Functor,
                              InT,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              (Arity > 1)>
        <<<blocks, threads, 0, stream>>>(ins_data,
                                         outs_data,
                                         use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         read_lens,
                                         func);
  } else {
    VectorizedBroadcastKernel<Functor,
                              InT,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              false>
        <<<blocks, threads, 0, stream>>>(ins_data,
                                         outs_data,
                                         use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         read_lens,
                                         func);
  }
#endif
}

#ifndef PADDLE_WITH_XPU_KP
// Maps a flat index expressed in `src_strides` to the matching flat index in
// `dst_strides`.
//
// Both stride arrays hold rank + 1 entries; entry k + 1 is used as the stride
// of dimension k. A destination dimension whose two neighbouring stride
// entries are equal contributes nothing to the result (with strides built as
// a running product of the shape, that is a dimension of extent 1, i.e. a
// broadcast dimension).
//
// Returns the flat destination index.
HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
    int64_t src_idx,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
    int rank) {
  int64_t dst_idx = 0;
  for (int k = 0; k < rank; ++k) {
    // Coordinate along dimension k, then strip it from the running index.
    auto local_idx = src_idx / src_strides[k + 1];
    src_idx -= local_idx * src_strides[k + 1];

    if (dst_strides[k] != dst_strides[k + 1]) {
      dst_idx += local_idx * dst_strides[k + 1];
    }
  }
  return dst_idx;
}

// Reads up to VecSize elements of `in` into `*out`, starting at flat index
// `idx`.
//
// IsBoundary:      only the first `n` elements are read, each through
//                  ConvertSrcIdxToDstIdx.
// !IsBoundary:     a full vector is read; with need_broadcast every element
//                  goes through ConvertSrcIdxToDstIdx, otherwise a single
//                  phi::Load from `in + idx` is issued.
template <typename T, int VecSize, bool IsBoundary>
HOSTDEVICE static void ReadVecDataWithInt64Index(
    const T *in,
    int64_t idx,
    bool need_broadcast,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
    int rank,
    int n,
    phi::AlignedVector<T, VecSize> *out) {
  if (IsBoundary) {
    // Partial vector: translate each of the n valid elements individually.
    for (int lane = 0; lane < n; ++lane) {
      const int64_t from =
          ConvertSrcIdxToDstIdx(idx + lane, src_strides, dst_strides, rank);
      (*out)[lane] = in[from];
    }
    return;
  }

  if (need_broadcast) {
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      const int64_t from =
          ConvertSrcIdxToDstIdx(idx + lane, src_strides, dst_strides, rank);
      (*out)[lane] = in[from];
    }
    return;
  }

  // Full vector, no broadcast: one vectorized load.
  phi::Load<T, VecSize>(in + idx, out);
}

// Applies `functor` to element `i` of each of the NumIns input vectors and
// casts the result to OutT. The primary template only declares Run;
// specializations for 0, 1, 2 and 3 inputs provide the definitions.
template <typename InT,
          typename OutT,
          typename Functor,
          int VecSize,
          int NumIns>
struct ApplyFunctorWithInt64IndexHelper {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i);
};

// Zero inputs: the functor is called without arguments.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 0> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    return static_cast<OutT>(functor());
  }
};

// One input.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 1> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    const auto &x = ins_vec[0];
    return static_cast<OutT>(functor(x[i]));
  }
};

// Two inputs.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 2> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    const auto &x = ins_vec[0];
    const auto &y = ins_vec[1];
    return static_cast<OutT>(functor(x[i], y[i]));
  }
};

// Three inputs.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 3> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    const auto &x = ins_vec[0];
    const auto &y = ins_vec[1];
    const auto &z = ins_vec[2];
    return static_cast<OutT>(functor(x[i], y[i], z[i]));
  }
};

// Compile-time max(N, 1); used to size arrays so that a zero input count
// still yields a one-element array.
template <int N>
struct MaxWithOne {
  static constexpr int kValue = (N < 1) ? 1 : N;
};

// Broadcast kernel that uses 64-bit indexing.
//
// The total element count is out_strides[0]. Each thread starts at
// (blockIdx.x * blockDim.x + threadIdx.x) * VecSize and advances by
// blockDim.x * gridDim.x * VecSize, handling one full vector of VecSize
// elements per step. A thread whose next start leaves fewer than VecSize
// elements handles that remainder element by element.
//
// ins / ins_strides / need_broadcasts are indexed per input; out_strides is
// passed as the source strides and ins_strides[i] as the destination strides
// when reading input i.
template <typename InT,
          typename OutT,
          typename Functor,
          int VecSize,
          int NumIns>
__global__ void BroadcastKernelWithInt64Index(
    phi::Array<const InT *, MaxWithOne<NumIns>::kValue> ins,
    OutT *out,
    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<NumIns>::kValue> ins_strides,
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides,
    phi::Array<bool, MaxWithOne<NumIns>::kValue> need_broadcasts,
    int rank,
    Functor functor) {
  using Apply =
      ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, NumIns>;

  const int64_t total = out_strides[0];
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  // Last start position at which a full vector still fits.
  const int64_t last_full_start = total - VecSize;
  int64_t offset =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;

  phi::Array<phi::AlignedVector<InT, VecSize>, MaxWithOne<NumIns>::kValue>
      ins_vec;
  phi::AlignedVector<OutT, VecSize> out_vec;

  // Full vectors.
  while (offset <= last_full_start) {
#pragma unroll
    for (int in_id = 0; in_id < NumIns; ++in_id) {
      ReadVecDataWithInt64Index<InT, VecSize, false>(ins[in_id],
                                                     offset,
                                                     need_broadcasts[in_id],
                                                     out_strides,
                                                     ins_strides[in_id],
                                                     rank,
                                                     VecSize,
                                                     &ins_vec[in_id]);
    }

#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      out_vec[lane] = Apply::Run(ins_vec.Get(), functor, lane);
    }

    phi::Store<OutT, VecSize>(out_vec, out + offset);
    offset += step;
  }

  // Nothing left for this thread.
  if (offset >= total) {
    return;
  }

  // Remainder: fewer than VecSize elements, so `int` is wide enough.
  const int tail = static_cast<int>(total - offset);
#pragma unroll
  for (int in_id = 0; in_id < NumIns; ++in_id) {
    ReadVecDataWithInt64Index<InT, VecSize, true>(ins[in_id],
                                                  offset,
                                                  need_broadcasts[in_id],
                                                  out_strides,
                                                  ins_strides[in_id],
                                                  rank,
                                                  tail,
                                                  &ins_vec[in_id]);
  }

  for (int lane = 0; lane < tail; ++lane) {
    out[offset + lane] = Apply::Run(ins_vec.Get(), functor, lane);
  }
}

// Primary template of the int64-index launcher. It only throws
// PermissionDenied; a partial specialization for NumOuts == 1 provides the
// real implementation.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
  // Always raises; none of the arguments are read.
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    PADDLE_THROW(phi::errors::PermissionDenied(
        "Unreachable code branch. This may be a bug."));
  }
};

// Int64-index launcher for the single-output case (NumOuts == 1).
//
// Run() expands every input shape to a common rank, converts the shapes into
// stride arrays of rank + 1 entries (entry 0 holds the element count), and
// launches BroadcastKernelWithInt64Index on ctx.stream().
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                 OutT,
                                                 Functor,
                                                 Arity,
                                                 /*NumOuts=*/1,
                                                 VecSize> {
  // ctx:     device context used for allocation and the launch stream.
  // ins:     Arity input tensors.
  // outs:    output tensors; only (*outs)[0] is written.
  // axis:    broadcast axis forwarded to CalculateBroadcastDims (must be >= 0
  //          whenever that function is called).
  // functor: element-wise functor forwarded to the kernel.
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    phi::Array<const InT *, MaxWithOne<Arity>::kValue> ins_ptrs;
    for (int i = 0; i < Arity; ++i) {
      ins_ptrs[i] = ins[i]->data<InT>();
    }
    auto *out_tensor = (*outs)[0];
    auto *out_ptr = ctx.Alloc<OutT>(out_tensor);

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank>,
               MaxWithOne<Arity>::kValue>
        ins_expand_dims;
    phi::Array<int64_t, phi::DDim::kMaxRank> broadcast_out_dims;
    int rank = 0;
    if (Arity == 0) {
      // No inputs: the output shape alone defines the iteration space.
      // `rank` used to be left uninitialized on this path.
      rank = out_tensor->dims().size();
      for (int i = 0; i < rank; ++i) {
        broadcast_out_dims[i] = out_tensor->dims()[i];
      }
    } else if (Arity == 1) {
      rank = ins[0]->dims().size();
      for (int i = 0; i < rank; ++i) {
        broadcast_out_dims[i] = ins[0]->dims()[i];
      }
      ins_expand_dims[0] = broadcast_out_dims;
    } else if (Arity >= 2) {
      CalculateBroadcastDims(ins[0]->dims().Get(),
                             ins[1]->dims().Get(),
                             ins[0]->dims().size(),
                             ins[1]->dims().size(),
                             axis,
                             ins_expand_dims[0].GetMutable(),
                             ins_expand_dims[1].GetMutable(),
                             broadcast_out_dims.GetMutable(),
                             &rank);
      // Fold each further input into the running broadcast shape.
      for (int i = 2; i < Arity; ++i) {
        auto tmp_dims = broadcast_out_dims;
        phi::Array<int64_t, phi::DDim::kMaxRank> tmp_expand_dims;
        int tmp_rank;
        PADDLE_ENFORCE_GE(rank,
                          ins[i]->dims().size(),
                          phi::errors::InvalidArgument(
                              "Unsupported reverse broadcast when the input "
                              "tensor number is larger than 2."));
        CalculateBroadcastDims(tmp_dims.Get(),
                               ins[i]->dims().Get(),
                               rank,
                               ins[i]->dims().size(),
                               axis,
                               tmp_expand_dims.GetMutable(),
                               ins_expand_dims[i].GetMutable(),
                               broadcast_out_dims.GetMutable(),
                               &tmp_rank);
        PADDLE_ENFORCE_EQ(rank,
                          tmp_rank,
                          phi::errors::InvalidArgument(
                              "Wrong broadcast algorithm. This may be a bug."));
      }
    }

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<Arity>::kValue>
        ins_strides;
    phi::Array<bool, MaxWithOne<Arity>::kValue> need_broadcasts;
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides;
    const auto &out_dims = out_tensor->dims();
    // Use the output tensor's own dims when they cover `rank` dimensions,
    // otherwise fall back to the computed broadcast shape.
    if (rank <= out_dims.size()) {
      out_strides = ShapeToStride(out_dims.Get(), rank);
    } else {
      out_strides = ShapeToStride(broadcast_out_dims.Get(), rank);
    }

    for (int i = 0; i < Arity; ++i) {
      ins_strides[i] = ShapeToStride(ins_expand_dims[i].Get(), rank);
      // An input needs broadcasting when any of its rank + 1 stride entries
      // differs from the output's.
      need_broadcasts[i] =
          !IsSameShape(out_strides.Get(), ins_strides[i].Get(), rank + 1);
    }

    int64_t numel = out_strides[0];
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);

    BroadcastKernelWithInt64Index<InT, OutT, Functor, VecSize, Arity>
        <<<gpu_config.block_per_grid,
           gpu_config.thread_per_block,
           0,
           ctx.stream()>>>(ins_ptrs,
                           out_ptr,
                           ins_strides,
                           out_strides,
                           need_broadcasts,
                           rank,
                           functor);
  }

 private:
  // Computes the expanded shapes of x and y and the broadcast result shape.
  //
  // x_dims / y_dims: input shapes with nx / ny entries.
  // axis:            offset in x at which y's dimensions are aligned when
  //                  nx > ny; must be >= 0.
  // x_out_dims / y_out_dims / broadcast_out_dims: outputs, *length entries.
  // length:          receives the resulting rank (max(nx, ny)).
  //
  // When nx < ny the roles of x and y are swapped via a recursive call.
  static void CalculateBroadcastDims(const int64_t *x_dims,
                                     const int64_t *y_dims,
                                     int nx,
                                     int ny,
                                     int axis,
                                     int64_t *x_out_dims,
                                     int64_t *y_out_dims,
                                     int64_t *broadcast_out_dims,
                                     int *length) {
    PADDLE_ENFORCE_GE(
        axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis));
    if (nx == ny) {
      *length = nx;
      for (int i = 0; i < nx; ++i) {
        if (x_dims[i] != y_dims[i]) {
          PADDLE_ENFORCE_EQ(
              x_dims[i] == 1 || y_dims[i] == 1,
              true,
              phi::errors::InvalidArgument("Cannot broadcast input shape where "
                                           "x_dims[%d] = %d, y_dims[%d] = %d.",
                                           i,
                                           x_dims[i],
                                           i,
                                           y_dims[i]));
        }
        broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i]);
        x_out_dims[i] = x_dims[i];
        y_out_dims[i] = y_dims[i];
      }
    } else if (nx > ny) {
      *length = nx;
      // y dimensions that would extend past the end of x must be 1.
      for (int i = nx - axis; i < ny; ++i) {
        PADDLE_ENFORCE_EQ(
            y_dims[i],
            1,
            phi::errors::InvalidArgument(
                "The trailing Y.shape[%d] should be 1 but got %d.",
                i,
                y_dims[i]));
      }

      for (int i = 0; i < nx; ++i) {
        if (i >= axis && i - axis < ny) {
          if (x_dims[i] != y_dims[i - axis]) {
            PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1,
                              true,
                              phi::errors::InvalidArgument(
                                  "Cannot broadcast input shape where "
                                  "x_dims[%d] = %d, y_dims[%d] = %d.",
                                  i,
                                  x_dims[i],
                                  i - axis,
                                  y_dims[i - axis]));
          }
          broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i - axis]);
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = y_dims[i - axis];
        } else {
          // Outside y's aligned range: y is treated as extent 1 here.
          broadcast_out_dims[i] = x_dims[i];
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = 1;
        }
      }
    } else {
      CalculateBroadcastDims(y_dims,
                             x_dims,
                             ny,
                             nx,
                             axis,
                             y_out_dims,
                             x_out_dims,
                             broadcast_out_dims,
                             length);
    }
  }

  // Returns true when the first `rank` entries of x and y are equal.
  static bool IsSameShape(const int64_t *x, const int64_t *y, int rank) {
    for (int i = 0; i < rank; ++i) {
      if (x[i] != y[i]) return false;
    }
    return true;
  }

  // Builds rank + 1 stride entries from a shape: strides[rank] == 1 and
  // strides[i] == strides[i + 1] * arr[i], so strides[0] is the element count.
  static phi::Array<int64_t, phi::DDim::kMaxRank + 1> ShapeToStride(
      const int64_t *arr, int rank) {
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> strides;
    strides[rank] = 1;
    for (int i = rank - 1; i >= 0; --i) {
      strides[i] = strides[i + 1] * arr[i];
    }
    return strides;
  }
};
#endif

// Validates the operand counts, selects a vectorization width and dispatches
// the broadcast to the matching kernel-launch instantiation.
//
// Template parameters:
//   ET       - elementwise type; its integer value is used as the arity when
//              the functor takes pointer arguments, otherwise the functor's
//              own arity is used.
//   InT/OutT - element types of the inputs / outputs.
//   Functor  - elementwise functor handed to the launch helpers.
//   NumOuts  - number of output tensors expected in `outs`.
// Parameters:
//   ctx  - device context forwarded to the launch helpers.
//   ins  - input tensors; their count must equal the arity (at most 3).
//   outs - output tensors; their count must equal NumOuts.
//   axis - broadcast axis forwarded to the dims simplifier and to the
//          int64-index helper.
//   func - functor instance forwarded to the launch helpers.
// Raises InvalidArgument when one of the count checks fails and Unimplemented
// when the selected vector width is none of VecSizeL / VecSizeM / VecSizeS.
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
  using Traits = paddle::platform::FunctionTraits<Functor>;
  const int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
  PADDLE_ENFORCE_EQ(
      ins.size(),
      kArity,
      phi::errors::InvalidArgument("The number of inputs is expected to be "
                                   "equal to the "
                                   "arity of functor. But recieved: the "
                                   "number of inputs "
                                   "is %d, the arity of functor is %d.",
                                   ins.size(),
                                   kArity));
  PADDLE_ENFORCE_LE(
      kArity,
      3,
      phi::errors::InvalidArgument("Currently only broadcast of ternary is "
                                   "supported "
                                   "and verified, but received %d.",
                                   kArity));
  PADDLE_ENFORCE_EQ(
      outs->size(),
      NumOuts,
      phi::errors::InvalidArgument("Number of outputs shall equal to number "
                                   "of functions, "
                                   "but number of outputs is %d, of "
                                   "functions is %d.",
                                   outs->size(),
                                   NumOuts));

#ifndef PADDLE_WITH_XPU_KP
  // Single-output broadcasts whose output holds at least INT32_MAX elements
  // are routed to the Int64Index launch helper and return early.
  // NOTE(review): presumably because 32-bit indexing cannot address such
  // outputs - confirm against the helper's implementation.
  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
  bool use_int64_index_kernel =
      kEnabledInt64IndexKernel &&
      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
  if (use_int64_index_kernel) {
    int vec_size = GetVecsize<InT, OutT>(ins, outs);
    switch (vec_size) {
      case VecSizeL: {
        LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                  OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeL>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeM: {
        LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                  OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeM>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeS: {
        LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                  OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeS>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      default: {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Unsupported vectorized size: %d!", vec_size));
        break;
      }
    }
    return;
  }
#endif

  // Merge dimensions of the inputs against the first output's shape.
  const auto dims_simplifier =
      BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
  if (VLOG_IS_ON(4)) {
    for (size_t i = 0; i < dims_simplifier.in_dims.size(); ++i) {
      VLOG(4) << "input i=" << i << ": origin_dims={" << ins[i]->dims()
              << "}, simplied_dims={"
              << phi::make_ddim(dims_simplifier.in_dims[i]) << "}";
    }
    VLOG(4) << "output: origin_dims={" << (*outs)[0]->dims()
            << "}, simplied_dims={" << phi::make_ddim(dims_simplifier.out_dims)
            << "}";
  }

  phi::Array<kps::details::BroadcastConfig, kArity> configs;

// Build the per-input broadcast configs and pick the vector width.
#ifdef PADDLE_WITH_XPU_KP
  PADDLE_ENFORCE_EQ(
      ins.size(),
      2,
      phi::errors::InvalidArgument(
          "XPU only support inputs is 2, but received %d", ins.size()));
  // Each config receives both simplified input shapes, with the config's own
  // input listed first.
  configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                             dims_simplifier.in_dims[0],
                                             dims_simplifier.in_dims[1],
                                             dims_simplifier.rank);
  configs[1] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                             dims_simplifier.in_dims[1],
                                             dims_simplifier.in_dims[0],
                                             dims_simplifier.rank);
  // The width depends only on whether the first config reports an
  // optimizable comparison type.
  auto type = kps::details::OptType::CanNotOptimize;
  bool is_optimize = configs[0].cmp_type != type;
  int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
  for (int i = 0; i < kArity; ++i) {
    // Get the broadcast config. If the data shape is [m, n], data_dim should
    // be set to {n, m}; e.g. for an output of shape [3, 45, 1], out_dims is
    // {1, 45, 3}.
    // Inputs with zero elements keep their default-constructed config.
    if (ins[i]->numel()) {
      configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
                                                 dims_simplifier.in_dims[i],
                                                 dims_simplifier.rank);
    }
  }
  int vec_size = GetVecsize<InT, OutT>(ins, outs);
#endif

  switch (vec_size) {
    case VecSizeL: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeL>(
          ctx, ins, outs, func, configs);
      break;
    }
    case VecSizeM: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeM>(
          ctx, ins, outs, func, configs);
      break;
    }
    case VecSizeS: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeS>(
          ctx, ins, outs, func, configs);
      break;
    }
    default: {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported vectorized size: %d!", vec_size));
      break;
    }
  }
}

// Entry point for broadcast elementwise kernels: resolves a defaulted axis
// and forwards everything to BroadcastKernelForDifferentVecSize.
//
// Parameters:
//   ctx  - device context forwarded unchanged.
//   ins  - input tensors.
//   outs - output tensors.
//   axis - broadcast axis. When it is -1 it is replaced by the difference
//          between the largest and the smallest input rank.
//   func - functor instance forwarded unchanged.
// When `ins` is empty there are no ranks to compare, so `axis` is forwarded
// as given and the callee's input-count check reports the problem.
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
                     const std::vector<const DenseTensor *> &ins,
                     std::vector<DenseTensor *> *outs,
                     int axis,
                     Functor func) {
  if (axis == -1 && !ins.empty()) {
    // Track the rank extremes in one pass; no temporary container needed.
    int max_rank = ins[0]->dims().size();
    int min_rank = max_rank;
    for (const auto *in : ins) {
      const int rank = in->dims().size();
      max_rank = std::max(max_rank, rank);
      min_rank = std::min(min_rank, rank);
    }
    axis = max_rank - min_rank;
  }
  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
      ctx, ins, outs, axis, func);
}

// Runs a binary broadcast elementwise computation on the GPU.
//
// Parameters:
//   dev_ctx - GPU context used to allocate `z` and to launch the kernel.
//   x, y    - the two input tensors, passed to the kernel in this order.
//   axis    - broadcast axis forwarded to BroadcastKernel.
//   func    - binary functor forwarded to BroadcastKernel.
//   z       - output tensor; its storage is allocated here as OutType.
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
                        const DenseTensor &x,
                        const DenseTensor &y,
                        int axis,
                        Functor func,
                        DenseTensor *z) {
  std::vector<const DenseTensor *> ins = {&x, &y};
  std::vector<DenseTensor *> outs = {z};
  // Allocate through the device context, matching how
  // DefaultElementwiseOperator in this file allocates its output.
  dev_ctx.template Alloc<OutType>(z);
  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
      dev_ctx, ins, &outs, axis, func);
}

// Allocates `z` as type T and applies `Functor` to `x` and `y` through
// ElementwiseCompute.
//
// Parameters:
//   dev_ctx - device context used for allocation and computation.
//   x, y    - input tensors.
//   z       - output tensor.
//   axis    - broadcast axis forwarded to ElementwiseCompute (default -1).
// `InverseFunctor` is not used in this build branch; it stays in the template
// parameter list so the signature matches the variant in the #else branch.
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  dev_ctx.template Alloc<T>(z);
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
}

#else
// Allocates `z` as type T and runs the elementwise computation on `x` and
// `y`. `Functor` is used when the rank of `x` is at least the rank of `y`;
// otherwise `InverseFunctor` is used. The operands are passed in the same
// order in both cases.
//
// Parameters:
//   dev_ctx - device context used for allocation and computation.
//   x, y    - input tensors.
//   z       - output tensor.
//   axis    - broadcast axis forwarded to ElementwiseCompute (default -1).
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  const bool y_has_higher_rank = x.dims().size() < y.dims().size();
  dev_ctx.template Alloc<T>(z);
  if (y_has_higher_rank) {
    funcs::ElementwiseCompute<InverseFunctor, T>(
        dev_ctx, x, y, axis, InverseFunctor(), z);
    return;
  }
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
}

#endif

}  // namespace funcs
}  // namespace phi