broadcast_function.h 43.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/kernels/funcs/elementwise_base.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
// Short alias for the kernel-primitive namespace used throughout this file.
namespace kps = phi::kps;
#endif

25
namespace phi {
26 27
namespace funcs {

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)

// Rewrites the shapes of the broadcast inputs and of the output into a
// compact, equivalent form for the broadcast kernels:
//   1. inputs with fewer dims than the output are padded with 1s, their dims
//      placed starting at output dimension `axis`; every shape (inputs and
//      output) is then reversed;
//   2. runs of consecutive dimensions on which all inputs agree are merged
//      into a single dimension;
//   3. runs of 1-valued dimensions of one input are merged when every other
//      input equals the output on that run.
// After construction `in_dims` / `out_dims` hold the rewritten shapes,
// `dim_size` their common rank, and `N` the number of rows in `in_dims`
// (at least 2; with a single input the output shape is used as row 1).
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  // Merge predicate: (in/out `equal` flag, in_dims, out_dims, dimension
  // index, number of input rows). Clears `equal` when dimension `i` must not
  // join the current merge run.
  using MergeFunctor =
      void (*)(bool &, std::vector<DimVector> &, DimVector &, int, int);

  int64_t N;
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // Pads every input shape with fewer dims than the output with 1s, placing
  // its dims at output dimensions [axis, axis + rank), and checks that every
  // input dim equals the matching output dim or is 1. Finally reverses all
  // input shapes and the output shape. Throws InvalidArgument on mismatch.
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      const int64_t in_rank = static_cast<int64_t>(in_dim.size());
      if (in_rank < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        // Each short input is aligned at the caller's `axis`. Advancing the
        // parameter itself made a second short input start where the first
        // one ended, and could index past the end of `out_dims`.
        int64_t out_idx = axis;
        for (int64_t in_idx = 0; in_idx < in_rank; ++in_idx, ++out_idx) {
          if (out_idx < 0 || out_idx >= dim_size) {
            PADDLE_THROW(phi::errors::InvalidArgument(
                "The %d-th dimension of input tensor cannot be aligned with "
                "the output tensor of rank %d at axis %d.",
                in_idx + 1,
                dim_size,
                axis));
          }
          if (in_dim[in_idx] != out_dims[out_idx] && in_dim[in_idx] != 1) {
            PADDLE_THROW(phi::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "with the %d-th dimension of output tensor %d or 1, but "
                "received %d.",
                in_idx + 1,
                out_idx + 1,
                out_dims[out_idx],
                in_dim[in_idx]));
          }
          tmp_dim[out_idx] = in_dim[in_idx];
        }
        in_dim.swap(tmp_dim);
      } else {
        for (int64_t in_idx = 0; in_idx < dim_size; ++in_idx) {
          if (in_dim[in_idx] != out_dims[in_idx] && in_dim[in_idx] != 1) {
            PADDLE_THROW(phi::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "with the %d-th dimension of output tensor %d or 1, but "
                "received %d.",
                in_idx + 1,
                in_idx + 1,
                out_dims[in_idx],
                in_dim[in_idx]));
          }
        }
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  // Merges every maximal run (length > 1) of consecutive dimensions accepted
  // by `merge_func` into one dimension holding the product of the run, in
  // every input shape and in the output shape, and shrinks `dim_size`.
  template <typename MergeFn>
  __inline__ void MergeDimensions(MergeFn merge_func, int N) {
    // Replaces vec[l_idx, m_idx) with the product of those entries.
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      // The init value fixes the accumulator type of std::accumulate; a plain
      // `1` would accumulate in `int` and truncate large products.
      (*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx,
                                          vec->begin() + m_idx,
                                          static_cast<int64_t>(1),
                                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

  // Finds the input whose shape holds the longest run of consecutive
  // 1-valued dims. If that run is longer than 1, swaps that input into
  // in_dims[0], stores its original index in *swap_index and returns true.
  // Otherwise returns false and leaves in_dims untouched.
  bool FindSequentialOneDim(int *swap_index) {
    int index = 0;
    int max_one_length = 0;
    for (int j = 0; j < N; ++j) {
      // Longest run of 1s in this input. Measuring only the last run
      // undercounted shapes like [1, 1, 1, 5, 1] (3 vs. 1).
      int seq_one_length = 0;
      int cur_run = 0;
      for (int i = 0; i < dim_size; ++i) {
        cur_run = (in_dims[j][i] == 1) ? cur_run + 1 : 0;
        seq_one_length = std::max(seq_one_length, cur_run);
      }
      if (seq_one_length > max_one_length) {
        index = j;
        max_one_length = seq_one_length;
      }
    }

    const bool has_seq_one = max_one_length > 1;
    if (has_seq_one) {
      std::swap(in_dims[0], in_dims[index]);
      *swap_index = index;
    }
    return has_seq_one;
  }

 public:
  explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
                               const phi::DDim &dims,
                               int axis) {
    N = std::max(static_cast<int>(ins.size()), 2);
    dim_size = dims.size();
    out_dims = phi::vectorize<int64_t>(dims);
    in_dims.resize(N);
    if (ins.size() == 1) {
      // A single input is broadcast to the output. The output shape is added
      // as the second row to avoid errors in dims merging.
      in_dims[0] = phi::vectorize<int64_t>(ins[0]->dims());
      in_dims[1] = out_dims;
    } else {
      for (int j = 0; j < N; ++j) {
        in_dims[j] = phi::vectorize<int64_t>(ins[j]->dims());
      }
    }
    InputDimensionsExtend(N, axis);

    // Merge consecutive dimensions on which all inputs agree, e.g.:
    //   in_1.shape = [2, 3, 4, 5]    in_1.shape = [2, 12, 5]
    //   in_2.shape = [1, 3, 4, 5] -> in_2.shape = [1, 12, 5]
    //   in_3.shape = [2, 3, 4, 1]    in_3.shape = [2, 12, 1]
    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out,
                                    int i,
                                    int num) {
      for (int j = 1; j < num; ++j) {
        equal &= (in_dims[0][i] == in_dims[j][i]);
      }
    };
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    // Merge runs of 1-valued dimensions of one input, e.g.:
    //   in_1.shape = [2, 1, 1, 5]    in_1.shape = [2,  1, 5]
    //   in_2.shape = [2, 3, 4, 5] -> in_2.shape = [2, 12, 5]
    //   in_3.shape = [2, 3, 4, 1]    in_3.shape = [2, 12, 1]
    // A dimension only joins a run when every other input equals the output
    // there; otherwise the merge would be incorrect.
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out,
                                        int i,
                                        int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal &= in_dims[j][i] == out[i];
        }
      }
    };
    for (auto i = 0; i < dim_size; ++i) {
      int swap_idx = 0;
      bool has_seq_one = FindSequentialOneDim(&swap_idx);
      if (!has_seq_one) break;
      merge_ptr = merge_sequential_one_dims;
      MergeDimensions<MergeFunctor>(merge_ptr, N);
      std::swap(in_dims[swap_idx], in_dims[0]);
    }
  }
};

227
template <typename InT, typename OutT>
228 229 230 231
int GetVecsize(const std::vector<const DenseTensor *> &ins,
               std::vector<DenseTensor *> *outs) {
  int in_vec_size = 4;
  int out_vec_size = 4;
232 233
  if (outs->size() > 1) {
    for (auto i = 1; i < outs->size(); ++i) {
234 235 236 237 238 239 240 241 242 243
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          phi::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
      out_vec_size = std::min(
          phi::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()), out_vec_size);
    }
244
  } else {
245 246 247 248 249 250 251 252
    out_vec_size = phi::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
  }

  for (auto *in : ins) {
    auto temp_size = phi::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == (*outs)[0]->dims()
                      ? std::min(temp_size, in_vec_size)
                      : in_vec_size;
253
  }
254
  return std::min(out_vec_size, in_vec_size);
255 256
}

257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
#ifndef PADDLE_WITH_XPU_KP
template <typename T,
          int VecSize,
          int Arity,
          bool IsBoundary,
          bool is_all_broadcast>
struct BroadcastDataLoader {
  __device__ __forceinline__ void operator()(
      T args[Arity][VecSize],
      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
      const phi::Array<int, Arity> &use_broadcast,
      const int block_offset,
      const int num,
      const uint32_t numel) {
#pragma unroll
    for (int i = 0; i < Arity; ++i) {
      kps::Init<T, VecSize>(args[i], static_cast<T>(1.0f));
      if (use_broadcast[i]) {
        kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
            args[i], ins[i], block_offset, configs[i], numel, VecSize);
      } else {
        kps::ReadData<T, VecSize, 1, IsBoundary>(
            args[i], ins[i] + block_offset, num, VecSize);
      }
    }
283
  }
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
};

// Loader specialization for is_all_broadcast == true. LaunchBroadcastKernel
// selects it when more than half of the inputs are broadcast. For each of
// the VecSize elements handled by this thread, the output coordinates are
// decoded once with configs[0].divmoders. They are combined with each
// input's strides, and every input is then gathered element by element.
template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, true> {
  __device__ __forceinline__ void operator()(
      T args[Arity][VecSize],
      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
      const phi::Array<int, Arity> &use_broadcast,
      const int block_offset,
      const int num,
      const uint32_t numel) {
    uint32_t index_bc[Arity][VecSize];
#pragma unroll
    for (int j = 0; j < Arity; ++j) {
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        index_bc[j][k] = 0;
        args[j][k] = static_cast<T>(1);
      }
    }

    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      // Elements at or past numel keep index 0, which is always in bounds.
      // `>=` also stops threads whose whole tile starts past the end; an
      // `== numel` test never matched those.
      if (IsBoundary) {
        if (idx >= numel) break;
      }

#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].kDims) break;
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
#pragma unroll
        for (int j = 0; j < Arity; ++j) {
          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
        }
      }
    }

#pragma unroll
    for (int j = 0; j < Arity; ++j) {
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        args[j][k] = ins[j][index_bc[j][k]];
      }
    }
  }
};
#endif
336

337 338 339 340 341 342
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
343 344
          bool IsBoundary,
          bool IsAllBroadcast = false>
345
__device__ void VectorizedBroadcastKernelImpl(
346 347 348
    const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    const phi::Array<int, Arity> &use_broadcast,
349
    const uint32_t numel,
350
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
351 352
    int num,
    int block_offset,
353
    int read_lens,
354
    Functor func) {
355 356
  __simd__ InT args[Arity][VecSize];
  __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
357
#ifdef PADDLE_WITH_XPU_KP
358
#pragma unroll
359
  for (int i = 0; i < Arity; ++i) {
360
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
361 362 363 364 365 366 367
    if (use_broadcast[i]) {
      kps::ReadDataBc<InT, VecSize, 1, IsBoundary>(
          args[i], ins[i], block_offset, configs[i], numel, read_lens);
    } else {
      kps::ReadData<InT, VecSize, 1, IsBoundary>(
          args[i], ins[i] + block_offset, num, read_lens);
    }
368
  }
369 370 371 372 373
#else
  BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, IsAllBroadcast>()(
      args, ins, configs, use_broadcast, block_offset, num, numel);
#endif

374 375
  constexpr bool kCallElementwiseAny =
      paddle::platform::FunctionTraits<Functor>::has_pointer_args;
376 377 378 379 380 381
  phi::funcs::ElementwisePrimitiveCaller<InT,
                                         ConditionalT<OutT, NumOuts>,
                                         VecSize,
                                         Functor,
                                         Arity,
                                         kCallElementwiseAny>()(
382 383 384 385
      func, args, result, read_lens);
  phi::funcs::
      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
          outs, result, block_offset, num, read_lens);
386 387
}

388 389
template <typename Functor,
          typename InT,
390 391 392
          typename OutT,
          int Arity,
          int NumOuts,
393 394
          int VecSize,
          bool IsAllBroadcast>
395
__global__ void VectorizedBroadcastKernel(
396 397 398
    phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    phi::Array<int, Arity> use_broadcast,
399
    uint32_t numel,
400
    phi::Array<kps::details::BroadcastConfig, Arity> configs,
401 402
    int main_offset,
    int tail_tid,
403
    int read_lens,
404
    Functor func) {
405
#ifdef PADDLE_WITH_XPU_KP
406 407
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
408 409 410 411 412 413 414
  for (; block_offset < main_offset; block_offset += stride) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
415 416 417 418 419 420 421 422 423 424
                                  false,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  BLOCK_NUM_X * read_lens,
                                                  block_offset,
                                                  read_lens,
                                                  func);
425 426 427 428 429 430 431 432 433
  }
  int num = numel - block_offset;
  if (num > 0) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
434 435 436 437 438 439 440 441 442 443
                                  true,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  num,
                                                  block_offset,
                                                  read_lens,
                                                  func);
444 445
  }
#else
446
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
447 448 449 450 451 452 453
  if (block_offset < main_offset) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
454 455 456 457 458 459 460 461 462 463
                                  false,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  BLOCK_NUM_X * VecSize,
                                                  block_offset,
                                                  read_lens,
                                                  func);
464 465 466 467 468 469 470
  } else {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
471 472 473 474 475 476 477 478 479 480
                                  true,
                                  IsAllBroadcast>(ins,
                                                  outs,
                                                  use_broadcast,
                                                  numel,
                                                  configs,
                                                  tail_tid,
                                                  block_offset,
                                                  read_lens,
                                                  func);
481 482 483 484 485 486 487 488 489
  }
#endif
}

// Allocates the outputs and marks each input whose numel differs from the
// output's as broadcast. It then launches VectorizedBroadcastKernel.
// XPU: 8 blocks of 64 threads are used, and each thread reads
//   configs[0].buf_len elements.
// GPU: the launch geometry comes from GetGpuLaunchConfig1D. The
//   all-broadcast loader is selected (for Arity > 1) when more than half of
//   the inputs are broadcast.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
void LaunchBroadcastKernel(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs) {
  int broadcast_num = 0;
  int numel = (*outs)[0]->numel();
  phi::Array<int, Arity> use_broadcast;
  phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;

  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
  }

  for (int i = 0; i < Arity; ++i) {
    if (ins[i]->numel() != numel) {
      broadcast_num++;
      use_broadcast[i] = true;
    } else {
      use_broadcast[i] = false;
    }
    ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
  }

#ifdef PADDLE_WITH_XPU_KP
  const int threads = 64;
  const int blocks = 8;
  int read_lens = configs[0].buf_len;
  auto stream = ctx.x_context()->xpu_stream;
  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  int tail_tid = numel % (read_lens * threads);

  VectorizedBroadcastKernel<Functor, InT, OutT, Arity, NumOuts, VecSize, false>
      <<<blocks, threads, 0, stream>>>(ins_data,
                                       outs_data,
                                       use_broadcast,
                                       numel,
                                       configs,
                                       main_offset,
                                       tail_tid,
                                       read_lens,
                                       func);
#else
  auto gpu_config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
  int read_lens = VecSize;
  auto stream = ctx.stream();
  auto threads = gpu_config.thread_per_block;
  auto blocks = gpu_config.block_per_grid;
  int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) *
                    read_lens * gpu_config.GetBlockSize();
  int tail_tid = numel % (read_lens * gpu_config.GetBlockSize());

  if (broadcast_num > (Arity >> 1)) {
    VectorizedBroadcastKernel<Functor,
                              InT,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              (Arity > 1)>
        <<<blocks, threads, 0, stream>>>(ins_data,
                                         outs_data,
                                         use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         read_lens,
                                         func);
  } else {
    VectorizedBroadcastKernel<Functor,
                              InT,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              false>
        <<<blocks, threads, 0, stream>>>(ins_data,
                                         outs_data,
                                         use_broadcast,
                                         numel,
                                         configs,
                                         main_offset,
                                         tail_tid,
                                         read_lens,
                                         func);
  }
#endif
}

584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974
#ifndef PADDLE_WITH_XPU_KP
HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx(
    int64_t src_idx,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
    int rank) {
  int64_t dst_idx = 0;
  int64_t old_src_idx = src_idx;
  for (int k = 0; k < rank; ++k) {
    auto local_idx = src_idx / src_strides[k + 1];
    src_idx -= local_idx * src_strides[k + 1];

    if (dst_strides[k] != dst_strides[k + 1]) {
      dst_idx += local_idx * dst_strides[k + 1];
    }
  }
  return dst_idx;
}

// Fills `out` with input values for the output elements [idx, idx + count).
// `count` is `n` on the boundary path and VecSize otherwise.
// - Interior, non-broadcast input: the data is contiguous, so one vectorized
//   Load is used.
// - Otherwise: each element's input index is computed with
//   ConvertSrcIdxToDstIdx and gathered individually.
template <typename T, int VecSize, bool IsBoundary>
HOSTDEVICE static void ReadVecDataWithInt64Index(
    const T *in,
    int64_t idx,
    bool need_broadcast,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &src_strides,
    const phi::Array<int64_t, phi::DDim::kMaxRank + 1> &dst_strides,
    int rank,
    int n,
    phi::AlignedVector<T, VecSize> *out) {
  if (!IsBoundary && !need_broadcast) {
    phi::Load<T, VecSize>(in + idx, out);
    return;
  }

  if (IsBoundary) {
    for (int lane = 0; lane < n; ++lane) {
      const int64_t src = idx + lane;
      (*out)[lane] =
          in[ConvertSrcIdxToDstIdx(src, src_strides, dst_strides, rank)];
    }
  } else {
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      const int64_t src = idx + lane;
      (*out)[lane] =
          in[ConvertSrcIdxToDstIdx(src, src_strides, dst_strides, rank)];
    }
  }
}

// Calls `functor` on lane `i` of the first NumIns vectors in `ins_vec` and
// casts the result to OutT. Specializations below cover NumIns = 0..3; the
// primary template is only declared here.
template <typename InT,
          typename OutT,
          typename Functor,
          int VecSize,
          int NumIns>
struct ApplyFunctorWithInt64IndexHelper {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i);
};

// No inputs: the functor is called without arguments.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 0> {
  HOSTDEVICE static OutT Run(
      const phi::AlignedVector<InT, VecSize> * /*ins_vec*/,
      Functor functor,
      int /*i*/) {
    return static_cast<OutT>(functor());
  }
};

// One input.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 1> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    const auto &x = ins_vec[0];
    return static_cast<OutT>(functor(x[i]));
  }
};

// Two inputs.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 2> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    const auto &lhs = ins_vec[0];
    const auto &rhs = ins_vec[1];
    return static_cast<OutT>(functor(lhs[i], rhs[i]));
  }
};

// Three inputs.
template <typename InT, typename OutT, typename Functor, int VecSize>
struct ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, 3> {
  HOSTDEVICE static OutT Run(const phi::AlignedVector<InT, VecSize> *ins_vec,
                             Functor functor,
                             int i) {
    const auto &a = ins_vec[0];
    const auto &b = ins_vec[1];
    const auto &c = ins_vec[2];
    return static_cast<OutT>(functor(a[i], b[i], c[i]));
  }
};

// Compile-time max(N, 1). Fixed-size arrays sized by an input count use it
// so they never have zero length.
template <int N>
struct MaxWithOne {
  static constexpr auto kValue = (N < 1) ? 1 : N;
};

// Broadcast element-wise kernel that uses int64_t offsets throughout.
// Each thread starts at (global thread id) * VecSize and advances by
// blockDim.x * gridDim.x * VecSize elements. Full vectors are read with
// ReadVecDataWithInt64Index and written with a vector Store. The trailing
// partial vector (fewer than VecSize elements), if any, is processed
// element by element by the one thread whose offset lands in it.
template <typename InT,
          typename OutT,
          typename Functor,
          int VecSize,
          int NumIns>
__global__ void BroadcastKernelWithInt64Index(
    phi::Array<const InT *, MaxWithOne<NumIns>::kValue> ins,
    OutT *out,
    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<NumIns>::kValue> ins_strides,
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides,
    phi::Array<bool, MaxWithOne<NumIns>::kValue> need_broadcasts,
    int rank,
    Functor functor) {
  using Applier =
      ApplyFunctorWithInt64IndexHelper<InT, OutT, Functor, VecSize, NumIns>;

  // out_strides[0] holds the number of output elements.
  const int64_t total = out_strides[0];
  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;
  // Largest start offset at which a full vector still fits.
  const int64_t last_full_start = total - VecSize;
  int64_t offset =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize;

  phi::Array<phi::AlignedVector<InT, VecSize>, MaxWithOne<NumIns>::kValue>
      in_vecs;
  phi::AlignedVector<OutT, VecSize> out_vec;

  while (offset <= last_full_start) {
#pragma unroll
    for (int k = 0; k < NumIns; ++k) {
      ReadVecDataWithInt64Index<InT, VecSize, false>(ins[k],
                                                     offset,
                                                     need_broadcasts[k],
                                                     out_strides,
                                                     ins_strides[k],
                                                     rank,
                                                     VecSize,
                                                     &in_vecs[k]);
    }

#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      out_vec[lane] = Applier::Run(in_vecs.Get(), functor, lane);
    }

    phi::Store<OutT, VecSize>(out_vec, out + offset);
    offset += step;
  }

  if (offset >= total) {
    return;
  }

  // Fewer than VecSize elements remain, so `int` is wide enough.
  const int remain = static_cast<int>(total - offset);
#pragma unroll
  for (int k = 0; k < NumIns; ++k) {
    ReadVecDataWithInt64Index<InT, VecSize, true>(ins[k],
                                                  offset,
                                                  need_broadcasts[k],
                                                  out_strides,
                                                  ins_strides[k],
                                                  rank,
                                                  remain,
                                                  &in_vecs[k]);
  }

  for (int lane = 0; lane < remain; ++lane) {
    out[offset + lane] = Applier::Run(in_vecs.Get(), functor, lane);
  }
}

template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper {
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    PADDLE_THROW(phi::errors::PermissionDenied(
        "Unreachable code branch. This may be a bug."));
  }
};

// Int64-indexed broadcast launcher for the single-output case (NumOuts == 1).
//
// Run() derives the broadcast shape of all inputs, converts every expanded
// input shape and the output shape into stride arrays (see ShapeToStride),
// and launches BroadcastKernelWithInt64Index over out_strides[0] elements on
// ctx.stream().
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                 OutT,
                                                 Functor,
                                                 Arity,
                                                 /*NumOuts=*/1,
                                                 VecSize> {
  // ctx     - device context used to allocate the output and to launch.
  // ins     - the Arity input tensors, read as InT.
  // outs    - (*outs)[0] is allocated as OutT and receives the result.
  // axis    - alignment axis for inputs of different rank; must be >= 0.
  // functor - elementwise functor forwarded to the kernel.
  static void Run(const KPDevice &ctx,
                  const std::vector<const DenseTensor *> &ins,
                  std::vector<DenseTensor *> *outs,
                  int axis,
                  Functor functor) {
    phi::Array<const InT *, MaxWithOne<Arity>::kValue> ins_ptrs;
    for (int i = 0; i < Arity; ++i) {
      ins_ptrs[i] = ins[i]->data<InT>();
    }
    auto *out_tensor = (*outs)[0];
    auto *out_ptr = ctx.Alloc<OutT>(out_tensor);

    // ins_expand_dims[i]: shape of input i padded with 1s to `rank` dims.
    // broadcast_out_dims: broadcast shape of all inputs.
    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank>,
               MaxWithOne<Arity>::kValue>
        ins_expand_dims;
    phi::Array<int64_t, phi::DDim::kMaxRank> broadcast_out_dims;
    // Zero-initialized: neither branch below assigns it when Arity == 0, and
    // it must never be read indeterminate.
    int rank = 0;
    if (Arity == 1) {
      rank = ins[0]->dims().size();
      for (int i = 0; i < rank; ++i) {
        broadcast_out_dims[i] = ins[0]->dims()[i];
      }
      ins_expand_dims[0] = broadcast_out_dims;
    } else if (Arity >= 2) {
      // Broadcast the first two inputs against each other...
      CalculateBroadcastDims(ins[0]->dims().Get(),
                             ins[1]->dims().Get(),
                             ins[0]->dims().size(),
                             ins[1]->dims().size(),
                             axis,
                             ins_expand_dims[0].GetMutable(),
                             ins_expand_dims[1].GetMutable(),
                             broadcast_out_dims.GetMutable(),
                             &rank);
      // ...then fold every remaining input into the running broadcast shape.
      // These inputs may not exceed the running rank, so `rank` must not
      // change.
      for (int i = 2; i < Arity; ++i) {
        auto tmp_dims = broadcast_out_dims;
        phi::Array<int64_t, phi::DDim::kMaxRank> tmp_expand_dims;
        int tmp_rank;
        PADDLE_ENFORCE_GE(rank,
                          ins[i]->dims().size(),
                          phi::errors::InvalidArgument(
                              "Unsupported reverse broadcast when the input "
                              "tensor number is larger than 2."));
        CalculateBroadcastDims(tmp_dims.Get(),
                               ins[i]->dims().Get(),
                               rank,
                               ins[i]->dims().size(),
                               axis,
                               tmp_expand_dims.GetMutable(),
                               ins_expand_dims[i].GetMutable(),
                               broadcast_out_dims.GetMutable(),
                               &tmp_rank);
        PADDLE_ENFORCE_EQ(rank,
                          tmp_rank,
                          phi::errors::InvalidArgument(
                              "Wrong broadcast algorithm. This may be a bug."));
      }
    }

    phi::Array<phi::Array<int64_t, phi::DDim::kMaxRank + 1>,
               MaxWithOne<Arity>::kValue>
        ins_strides;
    phi::Array<bool, MaxWithOne<Arity>::kValue> need_broadcasts;
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> out_strides;
    // Use the output tensor's own dims when it has at least `rank` dims;
    // otherwise fall back to the computed broadcast shape.
    const auto &out_dims = out_tensor->dims();
    if (rank <= out_dims.size()) {
      out_strides = ShapeToStride(out_dims.Get(), rank);
    } else {
      out_strides = ShapeToStride(broadcast_out_dims.Get(), rank);
    }

    // An input needs index remapping only when its strides (rank + 1
    // entries) differ from the output strides.
    for (int i = 0; i < Arity; ++i) {
      ins_strides[i] = ShapeToStride(ins_expand_dims[i].Get(), rank);
      need_broadcasts[i] =
          !IsSameShape(out_strides.Get(), ins_strides[i].Get(), rank + 1);
    }

    // out_strides[0] is the product of all output dims: the element count.
    int64_t numel = out_strides[0];
    auto gpu_config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);

    BroadcastKernelWithInt64Index<InT, OutT, Functor, VecSize, Arity>
        <<<gpu_config.block_per_grid,
           gpu_config.thread_per_block,
           0,
           ctx.stream()>>>(ins_ptrs,
                           out_ptr,
                           ins_strides,
                           out_strides,
                           need_broadcasts,
                           rank,
                           functor);
  }

 private:
  // Broadcasts shape x_dims (nx dims) against y_dims (ny dims).
  //
  // Outputs:
  //   broadcast_out_dims - the broadcast shape.
  //   x_out_dims / y_out_dims - the inputs padded with 1s to the common rank.
  //   *length - the common rank, max(nx, ny).
  //
  // With equal ranks `axis` is ignored. Otherwise dimension j of the
  // lower-rank shape lines up with dimension j + axis of the higher-rank one.
  // Lower-rank dimensions that fall past the end of the higher-rank shape
  // must be 1 and are dropped.
  //
  // Throws InvalidArgument for a negative axis or incompatible dimensions.
  static void CalculateBroadcastDims(const int64_t *x_dims,
                                     const int64_t *y_dims,
                                     int nx,
                                     int ny,
                                     int axis,
                                     int64_t *x_out_dims,
                                     int64_t *y_out_dims,
                                     int64_t *broadcast_out_dims,
                                     int *length) {
    PADDLE_ENFORCE_GE(
        axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis));
    if (nx == ny) {
      *length = nx;
      for (int i = 0; i < nx; ++i) {
        if (x_dims[i] != y_dims[i]) {
          PADDLE_ENFORCE_EQ(
              x_dims[i] == 1 || y_dims[i] == 1,
              true,
              phi::errors::InvalidArgument("Cannot broadcast input shape where "
                                           "x_dims[%d] = %d, y_dims[%d] = %d.",
                                           i,
                                           x_dims[i],
                                           i,
                                           y_dims[i]));
        }
        broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i]);
        x_out_dims[i] = x_dims[i];
        y_out_dims[i] = y_dims[i];
      }
    } else if (nx > ny) {
      *length = nx;
      // y_dims[i] lines up with x_dims[i + axis]; entries with i + axis >= nx
      // have no x counterpart and must be 1. The start index is clamped to 0
      // so that an axis larger than nx cannot yield a negative
      // (out-of-bounds) index into y_dims.
      for (int i = std::max(nx - axis, 0); i < ny; ++i) {
        PADDLE_ENFORCE_EQ(
            y_dims[i],
            1,
            phi::errors::InvalidArgument(
                "The trailing Y.shape[%d] should be 1 but got %d.",
                i,
                y_dims[i]));
      }

      for (int i = 0; i < nx; ++i) {
        if (i >= axis && i - axis < ny) {
          if (x_dims[i] != y_dims[i - axis]) {
            PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1,
                              true,
                              phi::errors::InvalidArgument(
                                  "Cannot broadcast input shape where "
                                  "x_dims[%d] = %d, y_dims[%d] = %d.",
                                  i,
                                  x_dims[i],
                                  i - axis,
                                  y_dims[i - axis]));
          }
          broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i - axis]);
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = y_dims[i - axis];
        } else {
          broadcast_out_dims[i] = x_dims[i];
          x_out_dims[i] = x_dims[i];
          y_out_dims[i] = 1;
        }
      }
    } else {
      // y has the higher rank: reuse the nx > ny logic with roles swapped.
      CalculateBroadcastDims(y_dims,
                             x_dims,
                             ny,
                             nx,
                             axis,
                             y_out_dims,
                             x_out_dims,
                             broadcast_out_dims,
                             length);
    }
  }

  // Returns true when the first `rank` entries of x and y are equal.
  static bool IsSameShape(const int64_t *x, const int64_t *y, int rank) {
    for (int i = 0; i < rank; ++i) {
      if (x[i] != y[i]) return false;
    }
    return true;
  }

  // Builds rank + 1 stride entries for a shape of `rank` dims:
  // strides[rank] = 1 and strides[i] = product of arr[i..rank-1]. Hence
  // strides[i + 1] is the row-major stride of dimension i and strides[0] is
  // the total element count.
  static phi::Array<int64_t, phi::DDim::kMaxRank + 1> ShapeToStride(
      const int64_t *arr, int rank) {
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> strides;
    strides[rank] = 1;
    for (int i = rank - 1; i >= 0; --i) {
      strides[i] = strides[i + 1] * arr[i];
    }
    return strides;
  }
};
#endif

975 976 977 978 979 980 981 982 983 984 985 986 987 988
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
  using Traits = paddle::platform::FunctionTraits<Functor>;
  const int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014
  PADDLE_ENFORCE_EQ(
      ins.size(),
      kArity,
      phi::errors::InvalidArgument("The number of inputs is expected to be "
                                   "equal to the "
                                   "arity of functor. But recieved: the "
                                   "number of inputs "
                                   "is %d, the arity of functor is %d.",
                                   ins.size(),
                                   kArity));
  PADDLE_ENFORCE_LE(
      kArity,
      3,
      phi::errors::InvalidArgument("Currently only broadcast of ternary is "
                                   "supported "
                                   "and verified, but received %d.",
                                   kArity));
  PADDLE_ENFORCE_EQ(
      outs->size(),
      NumOuts,
      phi::errors::InvalidArgument("Number of outputs shall equal to number "
                                   "of functions, "
                                   "but number of outputs is %d, of "
                                   "functions is %d.",
                                   outs->size(),
                                   NumOuts));
1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072

#ifndef PADDLE_WITH_XPU_KP
  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
  bool use_int64_index_kernel =
      kEnabledInt64IndexKernel &&
      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
  if (use_int64_index_kernel) {
    int vec_size = GetVecsize<InT, OutT>(ins, outs);
    switch (vec_size) {
      case VecSizeL: {
        LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                  OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeL>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeM: {
        LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                  OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeM>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      case VecSizeS: {
        LaunchBroadcastKernelWithInt64IndexHelper<InT,
                                                  OutT,
                                                  Functor,
                                                  kArity,
                                                  NumOuts,
                                                  VecSizeS>::Run(ctx,
                                                                 ins,
                                                                 outs,
                                                                 axis,
                                                                 func);
        break;
      }
      default: {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Unsupported vectorized size: %d!", vec_size));
        break;
      }
    }
    return;
  }
#endif

1073 1074 1075
  // mergedim and get vec_size
  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
  phi::Array<kps::details::BroadcastConfig, kArity> configs;
1076

1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095
// get vec_size
#ifdef PADDLE_WITH_XPU_KP
  PADDLE_ENFORCE_EQ(
      ins.size(),
      2,
      phi::errors::InvalidArgument(
          "XPU only support inputs is 2, but received %d", ins.size()));
  configs[0] = kps::details::BroadcastConfig(merge_dims.out_dims,
                                             merge_dims.in_dims[0],
                                             merge_dims.in_dims[1],
                                             merge_dims.dim_size);
  configs[1] = kps::details::BroadcastConfig(merge_dims.out_dims,
                                             merge_dims.in_dims[1],
                                             merge_dims.in_dims[0],
                                             merge_dims.dim_size);
  auto type = kps::details::OptType::CanNotOptimize;
  bool is_optimize = configs[0].cmp_type != type;
  int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
1096
  for (int i = 0; i < kArity; ++i) {
1097 1098 1099
    // get the broadcast config,
    // if data shape is[m, n], then you should set data_dim = {n, m}
    // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
1100
    // if (ins[i]->numel() != (*outs)[0]->numel()) {
1101 1102 1103 1104
    if (ins[i]->numel()) {
      configs[i] = kps::details::BroadcastConfig(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
1105
  }
1106
  int vec_size = GetVecsize<InT, OutT>(ins, outs);
1107
#endif
1108 1109

  switch (vec_size) {
1110 1111 1112
    case VecSizeL: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeL>(
          ctx, ins, outs, func, configs);
1113 1114
      break;
    }
1115 1116 1117
    case VecSizeM: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeM>(
          ctx, ins, outs, func, configs);
1118 1119
      break;
    }
1120 1121 1122
    case VecSizeS: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeS>(
          ctx, ins, outs, func, configs);
1123 1124 1125
      break;
    }
    default: {
1126
      PADDLE_THROW(phi::errors::Unimplemented(
1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143
          "Unsupported vectorized size: %d!", vec_size));
      break;
    }
  }
}

// Broadcast elementwise entry point. It normalizes `axis` and forwards to
// BroadcastKernelForDifferentVecSize, which validates the arguments and
// launches the kernel.
//
// An axis of -1 is replaced by (largest input rank - smallest input rank).
// Any other value is passed through unchanged.
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
                     const std::vector<const DenseTensor *> &ins,
                     std::vector<DenseTensor *> *outs,
                     int axis,
                     Functor func) {
  // One pass over the inputs, without a temporary vector. With no inputs,
  // `axis` is left unchanged rather than dereferencing an empty range.
  if (axis == -1 && !ins.empty()) {
    int max_rank = ins[0]->dims().size();
    int min_rank = max_rank;
    for (const auto *in : ins) {
      const int in_rank = in->dims().size();
      max_rank = std::max(max_rank, in_rank);
      min_rank = std::min(min_rank, in_rank);
    }
    axis = max_rank - min_rank;
  }
  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
      ctx, ins, outs, axis, func);
}

1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
                        const DenseTensor &x,
                        const DenseTensor &y,
                        int axis,
                        Functor func,
                        DenseTensor *z) {
  std::vector<const DenseTensor *> ins = {&x, &y};
  std::vector<DenseTensor *> outs = {z};
  z->mutable_data<OutType>(dev_ctx.GetPlace());
  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
      dev_ctx, ins, &outs, axis, func);
}

1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  dev_ctx.template Alloc<T>(z);
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
}

#else
1186

1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  dev_ctx.template Alloc<T>(z);
  if (x_dims.size() >= y_dims.size()) {
    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
  } else {
    funcs::ElementwiseCompute<InverseFunctor, T>(
        dev_ctx, x, y, axis, InverseFunctor(), z);
  }
}

1207 1208
#endif

1209
}  // namespace funcs
1210
}  // namespace phi