// elementwise_base.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"

#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

namespace kps = phi::kps;

#endif

34
namespace phi {
35

36 37 38 39 40
// Tags for the kind of element-wise functor. The enumerator names suggest
// one/two/three inputs and "any"; the numeric values are kept exactly as
// they were.
enum ElementwiseType {
  kUnary = 1,
  kBinary = 2,
  kTernary = 3,
  kAny = -1
};
/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
   for supporting multiple-output feature in elementwise system.*/
template <class T, int Num>
using ConditionalT =
41
    typename std::conditional_t<Num == 1, T, phi::Array<T, Num>>;
42 43

namespace funcs {
44
using DDim = phi::DDim;
45

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
// Per-index gradient functor: for index i it evaluates dx_op_ and dy_op_ on
// (x_[i], y_[i], out_[i], dout_[i]) and stores the results in dx_ / dy_.
//
// The member order must not change: the struct has no constructor and is
// brace-initialised positionally by its user.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
struct ElemwiseGradNoBroadcast {
  const T *x_;        // first argument handed to both ops
  const T *y_;        // second argument handed to both ops
  const Tout *out_;   // third argument handed to both ops
  const Tout *dout_;  // fourth argument handed to both ops

  // Writes element i of every output buffer that is not null. dx is written
  // before the dy operands are read, exactly as in the original statement
  // order.
  HOSTDEVICE void operator()(size_t i) {
    if (dx_) dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    if (dy_) dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
  }

  DX_OP dx_op_;  // produces the value stored in dx_
  DY_OP dy_op_;  // produces the value stored in dy_
  T *dx_;        // output buffer; nullptr skips the dx computation
  T *dy_;        // output buffer; nullptr skips the dy computation
};

68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
// Primary template, declared only: the usable definitions are the partial
// specialisations on the device-context type that follow in this header.
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;

// Primary template, declared only: the usable definitions are the partial
// specialisations on the device-context type that follow in this header.
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;

// NOTE(dzhwinter): ptrdiff_t in iterator is deprecated in c++17
//
// CPU input iterator that walks a buffer of `n` elements cyclically:
// dereferencing yields ptr[i] and incrementing moves i to (i + 1), wrapping
// back to 0 when it reaches n.
template <typename T>
class RowwiseTransformIterator<T, CPUContext>
    : public std::iterator<std::random_access_iterator_tag,
                           T,
                           std::ptrdiff_t,
                           T *,
                           T &> {
 public:
  // ptr: start of the buffer; n: number of elements before wrapping.
  RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

  RowwiseTransformIterator<T, CPUContext> &operator++() {
    ++i_;
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
    return *this;
  }

  // NOTE: unlike a conventional operator+, this advances *this in place by
  // `n` steps (non-positive n is a no-op) and returns a reference to it.
  // That behaviour is kept unchanged for existing callers.
  RowwiseTransformIterator<T, CPUContext> &operator+(int n) {
    while (n-- > 0) {
      ++(*this);
    }
    return *this;
  }

  // Two iterators compare equal when they currently refer to the same
  // element address.
  bool operator==(const RowwiseTransformIterator<T, CPUContext> &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const RowwiseTransformIterator<T, CPUContext> &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // const-qualified so that the comparison operators above, which
  // dereference a const reference, are well-formed when instantiated.
  const T &operator*() const { return ptr_[i_]; }

 private:
  const T *ptr_;
  int64_t i_;  // current position, always compared against the 64-bit n_
  int64_t n_;
};

// CPU input iterator for a buffer of `n` elements in which every element is
// repeated `post` times before moving to the next: dereferencing yields
// ptr[i]; each increment bumps an inner counter j, and only when j reaches
// `post` does i advance (wrapping to 0 at n).
template <typename T>
class MidWiseTransformIterator<T, CPUContext>
    : public std::iterator<std::random_access_iterator_tag,
                           T,
                           std::ptrdiff_t,
                           T *,
                           T &> {
 public:
  // ptr: start of the buffer; n: number of elements before wrapping;
  // post: number of increments spent on each element.
  MidWiseTransformIterator(const T *ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

  MidWiseTransformIterator<T, CPUContext> &operator++() {
    ++j_;
    if (UNLIKELY(j_ == post_)) {
      ++i_;
      j_ = 0;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  // NOTE: unlike a conventional operator+, this advances *this in place by
  // `n` steps (non-positive n is a no-op) and returns a reference to it.
  // That behaviour is kept unchanged for existing callers.
  MidWiseTransformIterator<T, CPUContext> &operator+(int n) {
    while (n-- > 0) {
      ++(*this);
    }
    return *this;
  }

  // Compares the addresses of the elements currently referred to.
  // NOTE(review): the inner repeat counter j_ is not part of the comparison,
  // so iterators at different repeat positions of the same element compare
  // equal. Confirm whether any caller relies on exact equality.
  bool operator==(const MidWiseTransformIterator<T, CPUContext> &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const MidWiseTransformIterator<T, CPUContext> &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // const-qualified so that the comparison operators above, which
  // dereference a const reference, are well-formed when instantiated.
  const T &operator*() const { return ptr_[i_]; }

 private:
  const T *ptr_;
  int64_t i_;     // index of the current element
  int64_t j_;     // increments already spent on the current element
  int64_t n_;     // number of elements before i_ wraps
  int64_t post_;  // increments per element
};

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
177 178
class RowwiseTransformIterator<T, GPUContext>
    : public thrust::iterator_adaptor<RowwiseTransformIterator<T, GPUContext>,
179 180
                                      const T *> {
 public:
181
  typedef thrust::iterator_adaptor<RowwiseTransformIterator<T, GPUContext>,
182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
                                   const T *>
      super_t;
  HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
      : super_t(x), begin_(x), n_(n) {}
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T *begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
197 198
class MidWiseTransformIterator<T, GPUContext>
    : public thrust::iterator_adaptor<MidWiseTransformIterator<T, GPUContext>,
199 200
                                      const T *> {
 public:
201
  typedef thrust::iterator_adaptor<MidWiseTransformIterator<T, GPUContext>,
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
                                   const T *>
      super_t;
  HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
      : super_t(x), begin_(x), n_(n), post_(post) {}
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T *begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

template <typename Functor,
          typename T,
          typename DeviceContext,
          typename OutType = T>
class TransformFunctor {
 public:
  TransformFunctor(const DenseTensor &x,
                   const DenseTensor &y,
                   DenseTensor *z,
                   const DeviceContext &ctx,
                   Functor func,
                   const bool is_xsize_larger = true)
      : x_(x.data<T>()),
        y_(y.data<T>()),
232
        z_(ctx.template Alloc<OutType>(z)),
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
        nx_(x.numel()),
        ctx_(ctx),
        func_(func),
        is_xsize_larger_(is_xsize_larger) {
    if (is_xsize_larger_ == false) {
      nx_ = y.numel();
    }
  }

  inline void Run() const {
    paddle::platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
  }

  inline void RunRowWise(int n, int pre) const {
    paddle::platform::Transform<DeviceContext> trans;
    if (is_xsize_larger_) {
      trans(ctx_,
            x_,
            x_ + nx_,
            RowwiseTransformIterator<T, DeviceContext>(y_, n),
            z_,
            func_);
    } else {
      trans(ctx_,
            y_,
            y_ + nx_,
            RowwiseTransformIterator<T, DeviceContext>(x_, n),
            z_,
            func_);
    }
  }

  inline void RunMidWise(int n, int pre, int post) const {
    paddle::platform::Transform<DeviceContext> trans;
    if (is_xsize_larger_) {
      trans(ctx_,
            x_,
            x_ + nx_,
            MidWiseTransformIterator<T, DeviceContext>(y_, n, post),
            z_,
            func_);
    } else {
      trans(ctx_,
            y_,
            y_ + nx_,
            MidWiseTransformIterator<T, DeviceContext>(x_, n, post),
            z_,
            func_);
    }
  }

 private:
  const T *x_;
  const T *y_;
  OutType *z_;
  int64_t nx_;
  const DeviceContext &ctx_;
  Functor func_;
  bool is_xsize_larger_;
};

// Returns `dims` with every trailing dimension of size 1 removed.
//  - No trailing 1s: `dims` is returned unchanged.
//  - All dimensions are 1 (or dims is empty after trimming): the result is
//    DDim(phi::make_dim()).
// The kept extents are copied as int64_t so that dimensions larger than
// INT_MAX are not truncated on the way through.
inline DDim trim_trailing_singular_dims(const DDim &dims) {
  // Remove trailing dimensions of size 1 for y
  auto actual_dims_size = dims.size();
  for (; actual_dims_size != 0; --actual_dims_size) {
    if (dims[actual_dims_size - 1] != 1) break;
  }
  if (actual_dims_size == dims.size()) return dims;
  if (actual_dims_size == 0) {
    return DDim(phi::make_dim());
  }
  std::vector<int64_t> trim_dims(actual_dims_size);
  for (decltype(actual_dims_size) i = 0; i < actual_dims_size; ++i) {
    trim_dims[i] = dims[i];
  }
  return phi::make_ddim(trim_dims);
}

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X' shape, they will be reshaped.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 *
 * New parameter: *is_run_common_broadcast* is a flag to record whether to run
 * common broadcast code.
 */
inline void get_mid_dims(const DDim &x_dims,
                         const DDim &y_dims,
                         const int axis,
                         int *pre,
                         int *n,
                         int *post,
                         int *is_run_common_broadcast) {
  *pre = 1;
  *n = 1;
  *post = 1;
  *is_run_common_broadcast = 0;
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }
  for (int i = 0; i < y_dims.size(); ++i) {
    if (x_dims[i + axis] != y_dims[i]) {
      PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1,
                        true,
346
                        phi::errors::InvalidArgument(
347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
                            "Broadcast dimension mismatch. Operands "
                            "could not be broadcast together with the shape of "
                            "X = [%s] and the shape of Y = [%s]. Received [%d] "
                            "in X is not equal to [%d] in Y.",
                            x_dims,
                            y_dims,
                            x_dims[i + axis],
                            y_dims[i]));
      *is_run_common_broadcast = 1;
      return;
    }
    (*n) *= y_dims[i];
  }
  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
}

365
// for broadcast backwards
366 367
static inline std::vector<int> GetReduceDim(const DDim &in,
                                            const DDim &out,
368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393
                                            int axis) {
  axis =
      (axis == -1 ? std::abs(static_cast<int>(out.size() - in.size())) : axis);
  std::vector<int> dims;
  for (int i = 0; i < axis; ++i) {
    dims.push_back(i);
  }
  for (int i = 0; i < in.size(); ++i) {
    if (out[i + axis] != in[i]) {
      dims.push_back(i + axis);
    }
  }
  for (int i = axis + in.size(); i < out.size(); ++i) {
    dims.push_back(i);
  }
  return dims;
}

template <typename DeviceContext, typename T>
static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx,
                                           const DenseTensor &x,
                                           const DenseTensor *ddx,
                                           DenseTensor *ddx_safe) {
  if (ddx) {
    *ddx_safe = *ddx;
  } else {
394 395
    auto meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout());
    *ddx_safe = phi::Empty(dev_ctx, std::move(meta));
396
    ddx_safe->mutable_data(dev_ctx.GetPlace());
397
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
398 399 400 401
    set_zero(dev_ctx, ddx_safe, static_cast<T>(0));
  }
}

402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418
template <typename DeviceContext,
          typename T,
          typename DX_OP,
          typename DY_OP,
          typename Tout = T>
void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
                                    const DDim &x_dim,
                                    const DDim &y_dim,
                                    const DenseTensor &x,
                                    const DenseTensor &y,
                                    const DenseTensor &out,
                                    const DenseTensor &dout,
                                    int axis,
                                    DenseTensor *dx,
                                    DenseTensor *dy,
                                    DX_OP dx_op,
                                    DY_OP dy_op) {
419
  size_t N = static_cast<size_t>(phi::product(x_dim));
420 421 422 423 424 425 426 427
  paddle::platform::ForRange<DeviceContext> for_range(dev_ctx, N);
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
      x.data<T>(),
      y.data<T>(),
      out.data<Tout>(),
      dout.data<Tout>(),
      dx_op,
      dy_op,
428 429
      dx == nullptr ? nullptr : dev_ctx.template Alloc<T>(dx),
      dy == nullptr ? nullptr : dev_ctx.template Alloc<T>(dy)});
430 431
}

432 433 434 435 436 437 438
// Copies dout's LoD onto dx; does nothing when dx is nullptr.
inline void ElementwiseGradPreProcess(const DenseTensor &dout,
                                      DenseTensor *dx) {
  if (dx == nullptr) {
    return;
  }
  dx->set_lod(dout.lod());
}

439 440
#if defined(__NVCC__) || defined(__HIPCC__)

441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503
// static unroller
//
// Compile-time loop: step(args...) calls Func<I, VecSize>::Apply(args...)
// for I = Begin, Begin + 1, ... in increasing order; the recursion stops at
// the partial specialisation where Begin == End.
template <template <int Index, int VecSize> typename Func,
          int VecSize,
          int End,
          int Begin = 0>
struct Unroller {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&... args) {
    // Every step receives the same arguments, so they are passed on as
    // lvalues instead of being std::forward-ed: forwarding an rvalue here
    // would allow this Apply to move from it before the remaining steps
    // read it.
    Func<Begin, VecSize>::Apply(args...);
    Unroller<Func, VecSize, End, Begin + 1>::step(args...);
  }
};

// Terminating case of the unroller (Begin == End): there is nothing left to
// apply, so step() accepts any arguments and does nothing.
template <template <int Index, int VecSize> typename Func, int VecSize, int End>
struct Unroller<Func, VecSize, End, End> {
  template <typename... Args>
  static HOSTDEVICE inline void step(Args &&... args) {}
};

// Unroller step that loads input number `Index` into slot `Index` of the
// argument tuples in `args`.
//   in          - array of raw input pointers; in[Index] is reinterpreted as
//                 a pointer to the Index-th tuple element type
//   args        - destination tuples
//   num         - element count forwarded to kps::ReadData
//   data_offset - element offset added to the input pointer
//   is_boundary - selects the ReadData instantiation whose last template
//                 argument is true instead of false
template <int Index, int VecSize>
struct Loader {
  template <typename Array, typename ArgsT>
  static __device__ void Apply(const Array &in,
                               ArgsT *args,
                               int num,
                               int data_offset,
                               bool is_boundary) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    // Pre-fill the slot with Type(1.0f) before reading.
    kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));

    const Type *src = reinterpret_cast<const Type *>(in[Index]) + data_offset;
    if (!is_boundary) {
      kps::ReadData<Type, VecSize, 1, 1, ArgsT, Index, false>(args, src, num);
      return;
    }
    kps::ReadData<Type, VecSize, 1, 1, ArgsT, Index, true>(args, src, num);
  }
};

// Unroller step that stores the data pointer of input tensor `Index`, viewed
// as a char pointer, into slot `Index` of *ins_data. VecSize is unused and
// exists only to fit the Unroller template signature.
template <int Index, int VecSize>
struct InputSetter {
  template <typename Array>
  static HOSTDEVICE void Apply(
      const std::vector<const DenseTensor *> &ins_tensor, Array *ins_data) {
    const DenseTensor *tensor = ins_tensor[Index];
    (*ins_data)[Index] = reinterpret_cast<const _ptr_ char *>(tensor->data());
  }
};

// Unroller step that lowers *vec_size to the vectorized size reported by
// paddle::platform::GetVectorizedSize for input tensor `Index`, read as the
// Index-th element type of ArgsT. `args` is only used to carry that type.
template <int Index, int VecSize>
struct VecSizeGetter {
  template <typename ArgsT>
  static HOSTDEVICE void Apply(const std::vector<const DenseTensor *> &ins,
                               const ArgsT &args,
                               int *vec_size) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    const int candidate =
        paddle::platform::GetVectorizedSize(ins[Index]->data<Type>());
    *vec_size = std::min<int>(*vec_size, candidate);
  }
};

template <typename OutT, typename Functor>
504 505
int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
                                const std::vector<DenseTensor *> &outs) {
506 507 508
  using Traits = paddle::platform::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;
  const int Arity = Traits::arity;
509
  int vec_size = 4;
510 511 512
  ArgsT arg;
  // The Arg VecSize=1 is to match the Unroller template.
  Unroller<VecSizeGetter, 1, Arity>::step(ins, arg, &vec_size);
513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541
  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
    vec_size = std::min<int>(
        vec_size, paddle::platform::GetVectorizedSize((*iter)->data<OutT>()));
  }
  return vec_size;
}

// Dispatches to the kps element-wise primitive that matches the functor's
// arity. The primary template only declares operator(); the partial
// specialisations below define it for CallElementwiseAny == true and for
// arities 0, 1, 2 and 3.
template <typename InT,
          typename OutT,
          int VecSize,
          typename Functor,
          int Arity,
          bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result);
};

// CallElementwiseAny == true: the whole `args` table goes to
// kps::ElementwiseAny, whatever the arity.
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
        result, args, func);
  }
};

// Arity 0: no inputs are read; `args` is ignored.
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseConstant<InT, OutT, VecSize, 1, 1, Functor>(result, func);
  }
};

// Arity 1: uses args[0].
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], func);
  }
};

// Arity 2: uses args[0] and args[1].
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], func);
  }
};

// Arity 3: uses args[0], args[1] and args[2].
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
  __device__ inline void operator()(Functor func,
                                    InT (*args)[VecSize],
                                    OutT *result) {
    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
        result, args[0], args[1], args[2], func);
  }
};

581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613
namespace detail {
// Calls `f` with the tuple elements selected by Index..., i.e.
// f(std::get<Index>(t)...). Both `f` and `t` are perfectly forwarded, and
// decltype(auto) returns exactly what the call expression yields.
template <class F, class Tuple, std::size_t... Index>
// GCC/Clang need the decltype() return type
HOSTDEVICE constexpr decltype(auto) ApplyImpl(F &&f,
                                              Tuple &&t,
                                              std::index_sequence<Index...>) {
  return std::forward<F>(f)(std::get<Index>(std::forward<Tuple>(t))...);
}
}  // namespace detail

// Invokes `f` with the elements of tuple `t` unpacked as individual
// arguments and returns whatever that call returns.
template <class F, class Tuple>
HOSTDEVICE constexpr decltype(auto) Apply(F &&f, Tuple &&t) {
  // One index per tuple element: 0, 1, ..., tuple_size - 1.
  using Indices = std::make_index_sequence<
      std::tuple_size<std::remove_reference_t<Tuple>>::value>;
  return detail::ApplyImpl(
      std::forward<F>(f), std::forward<Tuple>(t), Indices{});
}

// Evaluates the functor once per vector lane: for each idx in [0, VecSize)
// the argument tuple args[idx] is unpacked into func via Apply and the value
// is converted to OutT and stored in result[idx]. The Arity template
// parameter is not used in the body.
template <typename OutT,
          int VecSize,
          typename Functor,
          typename ArgsT,
          int Arity>
struct SameDimsElementwisePrimitiveCaller {
  __device__ inline void operator()(Functor func, ArgsT *args, OutT *result) {
    // The pragma must stay directly in front of the loop it annotates.
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      result[idx] = static_cast<OutT>(Apply(func, args[idx]));
    }
  }
};

614 615 616
// Multi-output writer. `src` holds VecSize packed results, each carrying one
// value per output; they are first transposed into one contiguous row per
// output and every row is then written with kps::WriteData to
// outs[k] + block_offset.
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCaller {
  __device__ __forceinline__ void operator()(
      phi::Array<_ptr_ OutT *, NumOuts> outs,
      ConditionalT<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num) {
    OutT per_output[NumOuts][VecSize];
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
#pragma unroll
      for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
        per_output[out_idx][lane] = (src[lane])[out_idx];
      }
    }
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[out_idx] + block_offset, per_output[out_idx], num);
    }
  }
};

template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
639
  __device__ __forceinline__ void operator()(phi::Array<_ptr_ OutT *, 1> outs,
640 641 642
                                             OutT src[VecSize],
                                             int block_offset,
                                             int num) {
643 644 645 646 647
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
        outs[0] + block_offset, src, num);
  }
};

648
// Processes one chunk of data in three stages:
//   1. load   - Unroller<Loader> fills `args` from every input at data_offset
//   2. compute - the functor is applied to each of the VecSize argument tuples
//   3. store  - the results are written to every output at data_offset
// IsBoundary is forwarded to the load and store stages; `num` is the element
// count they are given.
template <typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl(
    const phi::Array<const _ptr_ char *__restrict__, Arity> &in,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    int num,
    int data_offset,
    Functor func) {
  using Traits = paddle::platform::FunctionTraits<Functor>;
  using ArgsT = typename Traits::ArgsTuple;
  using ResultT = ConditionalT<OutT, NumOuts>;

  ArgsT args[VecSize];
  ResultT result[VecSize];

  Unroller<Loader, VecSize, Arity>::step(
      in, args, num, data_offset, IsBoundary);

  SameDimsElementwisePrimitiveCaller<ResultT, VecSize, Functor, ArgsT, Arity>
      compute;
  compute(func, args, result);

  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts> store;
  store(outs, result, data_offset, num);
}

679
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
680
__global__ void VectorizedElementwiseKernel(
681 682
    phi::Array<const _ptr_ char *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
683 684 685 686 687 688
    int size,
    int main_offset,
    Functor func) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  for (; data_offset < main_offset; data_offset += stride) {
689
    VectorizedElementwiseKernelImpl<OutT,
690 691 692 693 694 695 696 697 698 699
                                    Functor,
                                    Arity,
                                    NumOuts,
                                    VecSize,
                                    false>(
        ins, outs, VecSize * BLOCK_NUM_X, data_offset, func);
  }

  int num = size - data_offset;
  if (num > 0) {
700
    VectorizedElementwiseKernelImpl<OutT,
701 702 703 704 705 706 707 708
                                    Functor,
                                    Arity,
                                    NumOuts,
                                    VecSize,
                                    true>(ins, outs, num, data_offset, func);
  }
}

709
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
710 711 712 713
void ElementwiseCudaKernel(const KPDevice &ctx,
                           const std::vector<const DenseTensor *> &ins,
                           std::vector<DenseTensor *> *outs,
                           Functor func) {
714 715
  auto numel =
      (*outs)[0]->numel();  // To avoid running errors when ins.size()== 0
716 717
  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;
718

719
  Unroller<InputSetter, VecSize, Arity>::step(ins, &ins_data);
720
  for (int i = 0; i < NumOuts; ++i) {
721
    outs_data[i] = ctx.Alloc<OutT>((*outs)[i]);
722 723 724 725 726 727
  }
#ifdef PADDLE_WITH_XPU2
  int block_size = 64;
  int grid_size = 8;
  auto stream = ctx.x_context()->xpu_stream;
  int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size;
728
  VectorizedElementwiseKernel<OutT,
729 730 731 732 733 734
                              Functor,
                              Arity,
                              NumOuts,
                              VecSize><<<grid_size, block_size, 0, stream>>>(
      ins_data, outs_data, numel, main_offset, func);
#else
W
Wilber 已提交
735
  auto gpu_config =
736
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
737 738 739
  int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize *
                    gpu_config.GetBlockSize();
  auto stream = ctx.stream();
740
  VectorizedElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSize><<<
741 742 743 744 745 746 747
      gpu_config.block_per_grid,
      gpu_config.thread_per_block,
      0,
      stream>>>(ins_data, outs_data, numel, main_offset, func);
#endif
}

748
template <typename OutT, typename Functor, int NumOuts = 1>
749 750 751 752
void ElementwiseKernel(const KPDevice &ctx,
                       const std::vector<const DenseTensor *> &ins,
                       std::vector<DenseTensor *> *outs,
                       Functor func) {
753
  using Traits = paddle::platform::FunctionTraits<Functor>;
754
  const int kArity = Traits::arity;
755 756
  PADDLE_ENFORCE_EQ(ins.size(),
                    kArity,
757
                    phi::errors::InvalidArgument(
758 759 760 761 762 763 764
                        "The number of inputs is expected to be equal to the "
                        "arity of functor. But recieved: the number of inputs "
                        "is %d, the arity of functor is %d.",
                        ins.size(),
                        kArity));
  PADDLE_ENFORCE_EQ(outs->size(),
                    NumOuts,
765
                    phi::errors::InvalidArgument(
766 767 768 769 770 771 772 773 774 775
                        "Number of outputs shall equal to number of functions, "
                        "but number of outputs is %d, of functions is %d.",
                        outs->size(),
                        NumOuts));

  if (NumOuts > 1) {
    for (int i = 1; i < NumOuts; ++i) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
776
          phi::errors::InvalidArgument(
777 778 779 780 781 782 783
              "The shape of each output tensor shall be identical yet, "
              "but %dth output tensor`s shape is not.",
              i));
    }
  }

  // calculate the max vec_size for all ins and outs
784
  int vec_size = GetVectorizedSizeForTensors<OutT, Functor>(ins, *outs);
785 786
  switch (vec_size) {
    case 4:
787
      ElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, 4>(
788 789 790
          ctx, ins, outs, func);
      break;
    case 2:
791
      ElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, 2>(
792 793 794
          ctx, ins, outs, func);
      break;
    case 1:
795
      ElementwiseCudaKernel<OutT, Functor, kArity, NumOuts, 1>(
796 797 798
          ctx, ins, outs, func);
      break;
    default: {
799
      PADDLE_THROW(phi::errors::Unimplemented(
800 801 802 803 804 805 806
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}
#endif

807
}  // namespace funcs
808
}  // namespace phi