elementwise_grad_base.h 68.7 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include "paddle/phi/backends/all_context.h"
18
#include "paddle/phi/backends/gpu/gpu_info.h"
19
#include "paddle/phi/core/dense_tensor.h"
20
#include "paddle/phi/kernels/funcs/common_shape.h"
21 22 23 24 25 26
#include "paddle/phi/kernels/funcs/elementwise_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"

#if defined(__NVCC__) || defined(__HIPCC__)
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
27
#include "paddle/phi/backends/gpu/gpu_device_function.h"
28
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
29
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
30 31

#endif
32

33 34 35 36 37
// Block dimension assumed by the 1-D reduction kernels in this file: their
// threads stride over the reduced axis in steps of this value. HIP (ROCm)
// builds use a smaller value than CUDA builds.
#if defined(__HIPCC__)
constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
#else
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif
38

39 40 41
// Tile shape of the 2-D shared-memory reductions below, which declare
// `__shared__ T sdata[BLOCK_Y][BLOCK_X + 1]` and read it transposed (the
// extra column presumably avoids shared-memory bank conflicts on that read).
#define BLOCK_X 32
#define BLOCK_Y 32

42 43 44 45 46 47 48
// Stores `dividend / divisor` through the pointer `div` and
// `dividend % divisor` through the pointer `mod`.
// `dividend` is evaluated exactly once (it is copied first); `divisor` is
// evaluated twice, so it must be free of side effects. Every argument is
// parenthesized so expression arguments such as `b + c` or `ptr + 1` keep
// their intended precedence.
#define GetDivMod(dividend, divisor, div, mod) \
  do {                                         \
    const auto dividend_copy = (dividend);     \
    *(div) = dividend_copy / (divisor);        \
    *(mod) = dividend_copy % (divisor);        \
  } while (0)

49
namespace phi {
50

51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
namespace funcs {
// Short alias so the signatures in this namespace can spell the shape type as
// plain `DDim`.
using DDim = phi::DDim;

// Generic CPU backward for a broadcast elementwise binary op.
//
// Visits every element of the broadcast output, whose shape is
// out_dims_array[0..max_dim). For each output coordinate, the matching x and
// y elements are located through x_dims_array / y_dims_array (via
// GetElementwiseIndex). dx_op / dy_op are then evaluated on
// (x, y, out, dout). Their results are ACCUMULATED into dx / dy, so a
// broadcast operand collects the sum over every output position it fed.
//
// dx and dy are optional; a null pointer skips that gradient. Requested
// gradients are allocated through `ctx` and zero-filled with memset first.
// This presumes an all-zero byte pattern represents T(0).
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCPU(const DenseTensor &x,
                            const DenseTensor &y,
                            const DenseTensor &out,
                            const DenseTensor &dout,
                            DenseTensor *dx,
                            DenseTensor *dy,
                            int *x_dims_array,
                            int *y_dims_array,
                            int *out_dims_array,
                            int max_dim,
                            const CPUContext &ctx,
                            DX_OP dx_op,
                            DY_OP dy_op) {
  // Multi-dimensional cursor into the output, advanced once per element.
  std::vector<int> coord(max_dim, 0);

  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  const Tout *out_data = out.data<Tout>();
  const Tout *dout_data = dout.data<Tout>();

  T *dx_data = nullptr;
  if (dx != nullptr) {
    dx_data = ctx.Alloc<T>(dx);
  }
  T *dy_data = nullptr;
  if (dy != nullptr) {
    dy_data = ctx.Alloc<T>(dy);
  }

  // Gradients are accumulated with +=, so both buffers must start at zero.
  if (dx_data) {
    memset(dx_data, 0, dx->numel() * sizeof(T));
  }
  if (dy_data) {
    memset(dy_data, 0, dy->numel() * sizeof(T));
  }

  const int total = std::accumulate(
      out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
  for (int linear = 0; linear < total; ++linear) {
    const int xi = GetElementwiseIndex(x_dims_array, max_dim, coord.data());
    const int yi = GetElementwiseIndex(y_dims_array, max_dim, coord.data());
    if (dx_data) {
      dx_data[xi] += dx_op(
          x_data[xi], y_data[yi], out_data[linear], dout_data[linear]);
    }
    if (dy_data) {
      dy_data[yi] += dy_op(
          x_data[xi], y_data[yi], out_data[linear], dout_data[linear]);
    }
    UpdateElementwiseIndexArray(out_dims_array, max_dim, coord.data());
  }
}

// CPU backward for the "post == 1" broadcast layout: the larger operand is an
// h x w row-major matrix and the smaller one is a length-w vector broadcast
// across the h rows. `out` and `dout` always have the larger (h x w) shape.
//
//   is_xsize_larger == true : x is h x w, y has w elements. dx is written
//                             elementwise; dy[j] = sum over rows i.
//   is_xsize_larger == false: y is h x w, x has w elements (roles swapped).
//
// Either of dx / dy may be null to skip that gradient. Offsets are computed
// in size_t so that i * w cannot overflow `int` for large tensors, even when
// h and w each fit in `int`.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast1CPU(const T *x,
                                      const T *y,
                                      const Tout *out,
                                      const Tout *dout,
                                      int h,
                                      int w,
                                      bool is_xsize_larger,
                                      DX_OP dx_op,
                                      DY_OP dy_op,
                                      T *dx,
                                      T *dy) {
  const size_t row_len = static_cast<size_t>(w);
  if (is_xsize_larger) {
    for (int i = 0; i < h; ++i) {
      const size_t row_base = static_cast<size_t>(i) * row_len;
      for (int j = 0; j < w; ++j) {
        const size_t x_offset = row_base + static_cast<size_t>(j);
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          // The first row initializes the reduction; later rows accumulate.
          if (i == 0) {
            dy[j] = tmp;
          } else {
            dy[j] += tmp;
          }
        }
      }
    }
  } else {  // x.dims < y.dims, broadcast for x.
    for (int i = 0; i < h; ++i) {
      const size_t row_base = static_cast<size_t>(i) * row_len;
      for (int j = 0; j < w; ++j) {
        const size_t y_offset = row_base + static_cast<size_t>(j);
        if (dy != nullptr) {
          dy[y_offset] =
              dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
        }
        if (dx != nullptr) {
          T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
          // The first row initializes the reduction; later rows accumulate.
          if (i == 0) {
            dx[j] = tmp;
          } else {
            dx[j] += tmp;
          }
        }
      }
    }
  }
}

// CPU backward for the general (pre, n, post) broadcast layout: the larger
// operand is viewed as a row-major pre x n x post array, and the smaller one
// as a length-n vector broadcast over the pre and post axes. `out` and `dout`
// always have the larger shape.
//
//   is_xsize_larger == true : dx is written elementwise;
//                             dy[j] = sum over (i, k) of dy_op(...).
//   is_xsize_larger == false: roles of x and y are swapped.
//
// Either of dx / dy may be null to skip that gradient. Offsets are computed
// in size_t so that i * n * post cannot overflow `int` for large tensors.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast2CPU(const T *x,
                                      const T *y,
                                      const Tout *out,
                                      const Tout *dout,
                                      int pre,
                                      int n,
                                      int post,
                                      bool is_xsize_larger,
                                      DX_OP dx_op,
                                      DY_OP dy_op,
                                      T *dx,
                                      T *dy) {
  const size_t post_len = static_cast<size_t>(post);
  const size_t plane_len = static_cast<size_t>(n) * post_len;
  if (is_xsize_larger) {
    for (int i = 0; i < pre; ++i) {
      const size_t pre_base = static_cast<size_t>(i) * plane_len;
      for (int j = 0; j < n; ++j) {
        const size_t mid_base = pre_base + static_cast<size_t>(j) * post_len;
        for (int k = 0; k < post; ++k) {
          const size_t x_offset = mid_base + static_cast<size_t>(k);
          if (dx != nullptr) {
            dx[x_offset] =
                dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          }
          if (dy != nullptr) {
            T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
            // (i, k) == (0, 0) is the first visit of dy[j].
            if (i == 0 && k == 0) {
              dy[j] = tmp;
            } else {
              dy[j] += tmp;
            }
          }
        }
      }
    }
  } else {  // x.dims < y.dims, broadcast for x.
    for (int i = 0; i < pre; ++i) {
      const size_t pre_base = static_cast<size_t>(i) * plane_len;
      for (int j = 0; j < n; ++j) {
        const size_t mid_base = pre_base + static_cast<size_t>(j) * post_len;
        for (int k = 0; k < post; ++k) {
          const size_t y_offset = mid_base + static_cast<size_t>(k);
          if (dy != nullptr) {
            dy[y_offset] =
                dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
          }
          if (dx != nullptr) {
            T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
            // (i, k) == (0, 0) is the first visit of dx[j].
            if (i == 0 && k == 0) {
              dx[j] = tmp;
            } else {
              dx[j] += tmp;
            }
          }
        }
      }
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonElementwiseBroadcastBackward(const CPUContext &ctx,
                                        const DDim &x_dims,
                                        const DDim &y_dims,
                                        const DenseTensor &x,
                                        const DenseTensor &y,
                                        const DenseTensor &out,
                                        const DenseTensor &dout,
                                        int axis,
                                        DenseTensor *dx,
                                        DenseTensor *dy,
                                        DX_OP dx_op,
                                        DY_OP dy_op) {
  int max_dim = std::max(x_dims.size(), y_dims.size());
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  GetBroadcastDimsArrays(x_dims,
                         y_dims,
                         x_dims_array.data(),
                         y_dims_array.data(),
                         out_dims_array.data(),
                         max_dim,
                         axis);
  // for inplace strategy. memset will make dx and dout clear and get wrong
  // result.
  if (dx && dx->IsSharedBufferWith(dout)) {
    dx->clear();
240 241
    dx->Resize(x_dims);
    ctx.template Alloc<T>(dx);
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290
  }

  VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
          << phi::make_ddim(x_dims_array)
          << " ydim:" << phi::make_ddim(y_dims_array);

  CommonGradBroadcastCPU<T, DX_OP, DY_OP, Tout>(x,
                                                y,
                                                out,
                                                dout,
                                                dx,
                                                dy,
                                                x_dims_array.data(),
                                                y_dims_array.data(),
                                                out_dims_array.data(),
                                                max_dim,
                                                ctx,
                                                dx_op,
                                                dy_op);
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
                                      const DDim &x_dims,
                                      const DDim &y_dims,
                                      const DenseTensor &x,
                                      const DenseTensor &y,
                                      const DenseTensor &out,
                                      const DenseTensor &dout,
                                      int axis,
                                      DenseTensor *dx,
                                      DenseTensor *dy,
                                      DX_OP dx_op,
                                      DY_OP dy_op) {
  bool is_xsize_larger = true;

  int max_dim = x_dims.size();
  if (x_dims.size() < y_dims.size()) {
    is_xsize_larger = false;
    max_dim = y_dims.size();
  }

  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  PADDLE_ENFORCE_GE(
      axis,
      0,
      errors::InvalidArgument(
          "Axis should be great than or equal to 0, but received axis is %d.",
          axis));
291 292 293 294 295 296 297
  PADDLE_ENFORCE_LE(
      axis,
      max_dim,
      errors::InvalidArgument(
          "Axis should be less than or equal to %d, but received axis is %d.",
          max_dim,
          axis));
298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404

  int pre, n, post, is_run_common_broadcast, axis_trim = 0;
  if (is_xsize_larger) {
    auto y_dims_trimed = TrimTrailingSingularDims(y_dims);
    axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
    GetMidDims(x_dims,
               y_dims_trimed,
               axis_trim,
               &pre,
               &n,
               &post,
               &is_run_common_broadcast);
  } else {
    auto x_dims_trimed = TrimTrailingSingularDims(x_dims);
    axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
    GetMidDims(y_dims,
               x_dims_trimed,
               axis_trim,
               &pre,
               &n,
               &post,
               &is_run_common_broadcast);
  }
  // special case for common backward implementation.
  if (is_run_common_broadcast) {
    CommonElementwiseBroadcastBackward<T, DX_OP, DY_OP, Tout>(
        ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }
  if (post == 1) {
    ElemwiseGradBroadcast1CPU(x.data<T>(),
                              y.data<T>(),
                              out.data<Tout>(),
                              dout.data<Tout>(),
                              pre,
                              n,
                              is_xsize_larger,
                              dx_op,
                              dy_op,
                              dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
                              dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
  } else {
    ElemwiseGradBroadcast2CPU(x.data<T>(),
                              y.data<T>(),
                              out.data<Tout>(),
                              dout.data<Tout>(),
                              pre,
                              n,
                              post,
                              is_xsize_larger,
                              dx_op,
                              dy_op,
                              dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
                              dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
  }
}

// Per-element functor for the same-shape (no broadcasting) backward path.
// Every input is read at the same index i, and the requested gradients are
// written at i. Members are initialized positionally (aggregate
// initialization), so the declaration order of the data members below is
// part of the interface.
// HOSTDEVICE presumably expands to __host__ __device__ on GPU builds.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
struct ElemwiseGradNoBroadcast {
  // Forward inputs and the upstream gradient.
  const T *x_;
  const T *y_;
  const Tout *out_;
  const Tout *dout_;
  // Gradient functors, called as op(x, y, out, dout).
  DX_OP dx_op_;
  DY_OP dy_op_;
  // Gradient outputs; a null pointer means "not requested".
  T *dx_;
  T *dy_;

  HOSTDEVICE void operator()(size_t i) {
    if (dx_) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_) {
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
  }
};

// Backward for an elementwise binary op whose operands need no broadcasting.
// Runs ElemwiseGradNoBroadcast over indices [0, product(x_dim)) through
// phi::funcs::ForRange on `dev_ctx`. This assumes y, out and dout each hold
// at least that many elements.
// `y_dim` and `axis` are not used here; they keep the signature parallel to
// the broadcasting variant.
// dx / dy are optional; null pointers skip that gradient. Requested outputs
// are allocated through `dev_ctx`.
template <typename DeviceContext,
          typename T,
          typename DX_OP,
          typename DY_OP,
          typename Tout = T>
void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx,
                                    const DDim &x_dim,
                                    const DDim &y_dim,
                                    const DenseTensor &x,
                                    const DenseTensor &y,
                                    const DenseTensor &out,
                                    const DenseTensor &dout,
                                    int axis,
                                    DenseTensor *dx,
                                    DenseTensor *dy,
                                    DX_OP dx_op,
                                    DY_OP dy_op) {
  size_t N = static_cast<size_t>(phi::product(x_dim));
  phi::funcs::ForRange<DeviceContext> for_range(dev_ctx, N);
  // Aggregate order: x, y, out, dout, dx_op, dy_op, dx, dy.
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
      x.data<T>(),
      y.data<T>(),
      out.data<Tout>(),
      dout.data<Tout>(),
      dx_op,
      dy_op,
      dx == nullptr ? nullptr : dev_ctx.template Alloc<T>(dx),
      dy == nullptr ? nullptr : dev_ctx.template Alloc<T>(dy)});
}

407
#if defined(__NVCC__) || defined(__HIPCC__)
408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484
// Returns true when the axis positions in `broadcast_pos` form one run of
// consecutive, increasing integers (e.g. {2, 3, 4}). An empty or
// single-element list counts as contiguous.
static inline bool CheckContiguousDims(const std::vector<int> &broadcast_pos) {
  bool contiguous = true;
  for (size_t k = 1; contiguous && k < broadcast_pos.size(); ++k) {
    contiguous = broadcast_pos[k] == broadcast_pos[k - 1] + 1;
  }
  return contiguous;
}

// Builds an axis permutation of [0, max_dim) that moves the `x_one_size` axes
// listed in `x_one_indexs` to the end and keeps the remaining axes, in
// ascending order, at the front. `x_one_indexs` is expected to be sorted
// ascending, since the tail is matched against the axes in order.
// Example: max_dim = 4, x_one_indexs = {1, 3} -> x_trans_indexs = {0, 2, 1, 3}.
inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
                                               int *x_trans_indexs,
                                               const int max_dim,
                                               const int x_one_size) {
  const int tail_start = max_dim - x_one_size;
  // Listed axes occupy the tail, in their given order.
  for (int k = 0; k < x_one_size; ++k) {
    x_trans_indexs[tail_start + k] = x_one_indexs[k];
  }
  // Remaining axes fill the head. The head never grows past tail_start, so
  // the tail entries being matched are not overwritten.
  int head = 0;
  int tail = tail_start;
  for (int axis = 0; axis < max_dim; ++axis) {
    const bool listed = tail < max_dim && x_trans_indexs[tail] == axis;
    if (listed) {
      ++tail;
      continue;
    }
    x_trans_indexs[head++] = axis;
  }
}

// Returns true when the broadcast axes in `y_broadcast_pos` let the shape be
// split into two parts. The axes must start at axis 0 or end at axis
// max_dim - 1, and must be consecutive.
// `y_broadcast_pos` must be non-empty (its first element is read
// unconditionally).
static inline bool SplitDims(const std::vector<int> &y_broadcast_pos,
                             int max_dim) {
  const bool touches_an_end =
      y_broadcast_pos[0] == 0 ||
      y_broadcast_pos[y_broadcast_pos.size() - 1] == max_dim - 1;
  if (!touches_an_end) {
    return false;
  }
  for (size_t k = 1; k < y_broadcast_pos.size(); ++k) {
    if (y_broadcast_pos[k] != y_broadcast_pos[k - 1] + 1) {
      return false;  // a gap: the axes are not one contiguous run
    }
  }
  return true;
}

// Splits the output shape into the axes x already spans and the axes x is
// broadcast along:
//   *x_blocks  = product of dims where x_dims_array[d] == out_dims_array[d]
//   *x_threads = product of out_dims_array[d] over all other axes.
// With max_dim == 0, both results are 1.
inline void ComputeBroadcastKernelSize(int *x_dims_array,
                                       int *out_dims_array,
                                       int *x_blocks,
                                       int *x_threads,
                                       int max_dim) {
  *x_blocks = 1;
  *x_threads = 1;
  for (int d = 0; d < max_dim; ++d) {
    // On matching axes both arrays hold the same value, so multiplying by
    // the output extent is equivalent for either accumulator.
    int *acc = (x_dims_array[d] == out_dims_array[d]) ? x_blocks : x_threads;
    *acc *= out_dims_array[d];
  }
}

// One block per output element: block `bid` produces dd[bid].
// Its threads stride over the `n` reduced positions in steps of
// ELEMWISE_MAX_BLOCK_DIM and accumulate op(x, y, out, dout). The per-thread
// partial sums are then combined with phi::backends::gpu::reduceSum.
// When `is_xsize` is true, the reduced operand is x (read at a fixed
// per-block offset); otherwise the roles of x and y are swapped.
// NOTE(review): the y_pre / y_n / y_post offset arithmetic is unchanged from
// the original; confirm the intended layout against the launcher before
// modifying it.
template <typename T, typename OP, typename Tout = T>
static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x,
                                                            const T *y,
                                                            const Tout *out,
                                                            const Tout *dout,
                                                            int pre,
                                                            int n,
                                                            int post,
                                                            int y_pre,
                                                            int y_n,
                                                            int y_post,
                                                            bool is_xsize,
                                                            OP op,
                                                            T *dd) {
  // With no output there is nothing to compute. `dd` is a kernel argument,
  // so this exit is uniform across the block and cannot leave threads
  // waiting at the barrier inside reduceSum.
  if (dd == nullptr) {
    return;
  }

  int tid = THREAD_ID_X;
  int bid = BLOCK_ID_X;

  // Per-block coordinates; independent of the loop index.
  int b_i = bid / post;
  int b_j = bid % post;
  // Get y pre rows id with x post and y_pre.
  int b_yi = bid / (post * y_pre);
  int b_yj = bid % y_post;

  T val(0);
  if (is_xsize) {
    // Reduce for x: x is read at a fixed per-block offset.
    int x_offset = b_i * n * post + b_j;
    for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
      int out_offset = b_i * n * post + i * post + b_j;
      int y_offset = b_yi * y_n + i * y_post + b_yj;
      val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]);
    }
  } else {
    // Reduce for y: y is read at a fixed per-block offset.
    int y_offset = b_i * n * post + b_j;
    for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
      int out_offset = b_i * n * post + i * post + b_j;
      int x_offset = b_yi * y_n + i * y_post + b_yj;
      val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]);
    }
  }

  // Only threads with tid < n ran the loop; the rest still hold zero.
  int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
  val = phi::backends::gpu::reduceSum(val, tid, h);
  if (tid == 0) {
    dd[bid] = val;
  }
}

// One block per element of the smaller operand: block `bid` owns position
// (b_i, b_j) = (bid / post, bid % post). Its threads stride over the `n`
// positions that element was broadcast across, in steps of
// ELEMWISE_MAX_BLOCK_DIM.
// The larger operand's gradient is written elementwise. The smaller operand's
// gradient for this block is the reduceSum of the per-thread partial sums.
// NOTE(review): the template parameters are declared <T, DY_OP, DX_OP>, the
// opposite order from the dx_op / dy_op arguments. This is harmless when the
// functor types are deduced; be careful if passing them explicitly.
template <typename T, typename DY_OP, typename DX_OP, typename Tout = T>
static __global__ void FastCommonGradBroadcastAllCUDAKernel(
    const T *x,
    const T *y,
    const Tout *out,
    const Tout *dout,
    int pre,
    int n,
    int post,
    bool is_xsize_larger,
    DX_OP dx_op,
    DY_OP dy_op,
    T *dx,
    T *dy) {
  int tid = THREAD_ID_X;
  int bid = BLOCK_ID_X;

  // Block-invariant coordinates, hoisted out of the strided loops.
  int b_i = bid / post;
  int b_j = bid % post;
  // Offset of this block's element in the smaller operand.
  int small_offset = b_i * post + b_j;
  // Offset of this block's first element in the larger operand; consecutive
  // reduced positions are `post` apart.
  int large_base = b_i * n * post + b_j;

  T val(0);
  if (is_xsize_larger) {
    for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
      int x_offset = large_base + i * post;
      if (dx) {
        dx[x_offset] =
            dx_op(x[x_offset], y[small_offset], out[x_offset], dout[x_offset]);
      }
      if (dy) {
        val +=
            dy_op(x[x_offset], y[small_offset], out[x_offset], dout[x_offset]);
      }
    }
    if (dy) {
      int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
      val = phi::backends::gpu::reduceSum(val, tid, h);
      if (tid == 0) {
        dy[bid] = val;
      }
    }
  } else {
    for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
      int y_offset = large_base + i * post;
      if (dy) {
        dy[y_offset] =
            dy_op(x[small_offset], y[y_offset], out[y_offset], dout[y_offset]);
      }
      if (dx) {
        val +=
            dx_op(x[small_offset], y[y_offset], out[y_offset], dout[y_offset]);
      }
    }
    if (dx) {
      int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
      val = phi::backends::gpu::reduceSum(val, tid, h);
      if (tid == 0) {
        dx[bid] = val;
      }
    }
  }
}

// Column-sum backward using 2-D BLOCK_X x BLOCK_Y tiles: dy[m] is the sum over
// rows n < h of dy_op(...) for column m < w of the h x w output.
// The operand that is not being reduced is read at
// (n % x_h) * x_w + m % x_w, i.e. wrapped into its x_h x x_w extent.
//   is_y == true : that operand is x, and y is read at column m.
//   is_y == false: that operand is y, and x is read at column m.
// Each thread accumulates a partial column sum into shared memory. The tile
// is then read transposed and reduced with XOR warp shuffles. This assumes a
// warp spans one BLOCK_X-wide row of the block (BLOCK_X == warpSize == 32).
template <typename T, typename DY_OP, typename Tout = T>
static __global__ void FastCommonGradBroadcastCUDAKernelHeight(const T *x,
                                                               const T *y,
                                                               const Tout *out,
                                                               const Tout *dout,
                                                               int h,
                                                               int w,
                                                               DY_OP dy_op,
                                                               T *dy,
                                                               int x_h,
                                                               int x_w,
                                                               bool is_y) {
  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];

  // No output requested: nothing observable to do. `dy` is uniform across
  // the grid, so every thread exits together and no barrier is stranded.
  if (dy == nullptr) {
    return;
  }

  size_t width_stride = GRID_NUM_X * BLOCK_NUM_X;
  size_t idx = THREAD_ID_X + BLOCK_NUM_X * BLOCK_ID_X;
  // w and h rounded up to multiples of the tile size, so every thread in a
  // block runs the same number of iterations and reaches each barrier.
  size_t full_width =
      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
  size_t full_height =
      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);

  for (int m = idx; m < full_width; m += width_stride) {
    sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
    for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
      if (m < w && n < h) {
        int out_offset = n * w + m;
        int other_offset = (n % x_h) * x_w + m % x_w;
        T val = is_y ? dy_op(x[other_offset], y[m], out[out_offset],
                             dout[out_offset])
                     : dy_op(x[m], y[other_offset], out[out_offset],
                             dout[out_offset]);
        sdata[THREAD_ID_Y][THREAD_ID_X] += val;
      }
      __syncthreads();
    }

    // Transposed read: lanes of a warp now hold the partial sums of a single
    // column, which the XOR shuffle butterfly adds together.
    T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
    for (int i = warpSize >> 1; i > 0; i >>= 1) {
      my_val += phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
    }
    __syncthreads();
    if (THREAD_ID_X == 0) {
      sdata[0][THREAD_ID_Y] = my_val;
    }
    __syncthreads();
    if (THREAD_ID_Y == 0 && m < w) {
      dy[m] = sdata[0][THREAD_ID_X];
    }
  }
}

// One block per column: block j = BLOCK_ID_X computes dy[j], the sum over
// rows i < h of dy_op(...) for column j of the h x w output.
// The operand that is not reduced is read at (i % x_h) * x_w + j % x_w.
//   is_y == true : that operand is x, and y is read at j.
//   is_y == false: that operand is y, and x is read at j.
// Threads stride over rows in steps of ELEMWISE_MAX_BLOCK_DIM. The loop is
// guarded, so a thread whose index is already >= h does no reads (the
// previous do/while form executed one out-of-range iteration for such
// threads).
template <typename T, typename DY_OP, typename Tout = T>
static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x,
                                                            const T *y,
                                                            const Tout *out,
                                                            const Tout *dout,
                                                            int h,
                                                            int w,
                                                            DY_OP dy_op,
                                                            T *dy,
                                                            int x_h,
                                                            int x_w,
                                                            bool is_y) {
  // No output requested. `dy` is uniform across the block, so all threads
  // leave together and none waits inside reduceSum.
  if (dy == nullptr) {
    return;
  }

  int j = BLOCK_ID_X;
  int tid = THREAD_ID_X;

  T val(0);
  for (int i = tid; i < h; i += ELEMWISE_MAX_BLOCK_DIM) {
    int out_offset = i * w + j;
    int other_offset = (i % x_h) * x_w + j % x_w;
    if (is_y) {
      val += dy_op(x[other_offset], y[j], out[out_offset], dout[out_offset]);
    } else {
      val += dy_op(x[j], y[other_offset], out[out_offset], dout[out_offset]);
    }
  }

  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
  val = phi::backends::gpu::reduceSum(val, tid, h);
  if (THREAD_ID_X == 0) {
    dy[j] = val;
  }
}

// GPU counterpart of ElemwiseGradBroadcast1CPU for the "post == 1" layout:
// the larger operand is h x w (row-major) and the smaller has w elements.
// Block j = BLOCK_ID_X owns column j. Its threads stride over the rows in
// steps of ELEMWISE_MAX_BLOCK_DIM, writing the larger operand's gradient
// elementwise and reducing the smaller operand's gradient with reduceSum.
// The row loop is guarded, so a thread whose index is already >= h neither
// reads nor writes (the previous do/while form executed one out-of-range
// iteration, including a dx/dy store, for such threads).
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x,
                                                        const T *y,
                                                        const Tout *out,
                                                        const Tout *dout,
                                                        int h,
                                                        int w,
                                                        bool is_xsize_larger,
                                                        DX_OP dx_op,
                                                        DY_OP dy_op,
                                                        T *dx,
                                                        T *dy) {
  int j = BLOCK_ID_X;
  int tid = THREAD_ID_X;

  T val(0);
  if (is_xsize_larger) {
    for (int i = tid; i < h; i += ELEMWISE_MAX_BLOCK_DIM) {
      int x_offset = i * w + j;
      if (dx) {
        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }
      if (dy) {
        val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }
    }

    if (dy) {
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = phi::backends::gpu::reduceSum(val, tid, h);
      if (THREAD_ID_X == 0) {
        dy[j] = val;
      }
    }
  } else {  // x.dims < y.dims, broadcast for x.
    for (int i = tid; i < h; i += ELEMWISE_MAX_BLOCK_DIM) {
      int y_offset = i * w + j;
      if (dy) {
        dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
      }
      if (dx) {
        val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
      }
    }

    if (dx) {
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = phi::backends::gpu::reduceSum(val, tid, h);
      if (THREAD_ID_X == 0) {
        dx[j] = val;
      }
    }
  }
}

// A 2-D thread block is presumably faster here because it exposes more
// parallelism and keeps global-memory accesses coalesced.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
    const T *x,
    const T *y,
    const Tout *out,
    const Tout *dout,
    int h,
    int w,
    bool is_xsize_larger,
    DX_OP dx_op,
    DY_OP dy_op,
    T *dx,
    T *dy) {
  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];

  T val(0);
815 816
  size_t width_stride = GRID_NUM_X * BLOCK_NUM_X;
  size_t idx = THREAD_ID_X + BLOCK_NUM_X * BLOCK_ID_X;
817 818 819 820 821 822
  size_t full_width =
      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
  size_t full_height =
      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
  if (is_xsize_larger) {
    for (int m = idx; m < full_width; m += width_stride) {
823 824
      sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
      for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
825 826 827 828 829 830 831 832
        int x_offset = n * w + m;
        if (dx && m < w && n < h) {
          dx[x_offset] =
              dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
        }
        if (dy) {
          if (m < w && n < h) {
            T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
833
            sdata[THREAD_ID_Y][THREAD_ID_X] += val;
834 835 836 837 838
          }
          __syncthreads();
        }
      }
      if (dy) {
839
        T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
840
        for (int i = warpSize >> 1; i > 0; i >>= 1)
841 842
          my_val +=
              phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
843
        __syncthreads();
844 845
        if ((THREAD_ID_X == 0)) {
          sdata[0][THREAD_ID_Y] = my_val;
846 847
        }
        __syncthreads();
848 849
        if (THREAD_ID_Y == 0 && m < w) {
          dy[m] = sdata[0][THREAD_ID_X];
850 851 852 853 854
        }
      }
    }
  } else {  // x.dims < y.dims, broadcast for x.
    for (int m = idx; m < full_width; m += width_stride) {
855 856
      sdata[THREAD_ID_Y][THREAD_ID_X] = 0;
      for (int n = THREAD_ID_Y; n < full_height; n += BLOCK_Y) {
857 858 859 860 861 862 863 864
        int y_offset = n * w + m;
        if (dy && m < w && n < h) {
          dy[y_offset] =
              dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
        }
        if (dx) {
          if (m < w && n < h) {
            T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
865
            sdata[THREAD_ID_Y][THREAD_ID_X] += val;
866 867 868 869 870
          }
          __syncthreads();
        }
      }
      if (dx) {
871
        T my_val = sdata[THREAD_ID_X][THREAD_ID_Y];
872
        for (int i = warpSize >> 1; i > 0; i >>= 1)
873 874
          my_val +=
              phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
875
        __syncthreads();
876 877
        if ((THREAD_ID_X == 0)) {
          sdata[0][THREAD_ID_Y] = my_val;
878 879
        }
        __syncthreads();
880 881
        if (THREAD_ID_Y == 0 && m < w) {
          dx[m] = sdata[0][THREAD_ID_X];
882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900
        }
      }
    }
  }
}

// Gradient kernel for the "broadcast 2" pattern. The larger input is viewed as
// a (pre, n, post) volume and the smaller input as a length-n vector.
//
// Block BLOCK_ID_X owns one middle index `mid`. Its threads sweep the
// flattened (pre, post) plane with a stride of ELEMWISE_MAX_BLOCK_DIM:
//   * the gradient of the larger input is written element-wise at each point;
//   * the gradient of the smaller input at `mid` is the sum over the plane,
//     finished with reduceSum.
// A gradient pointer that is null is skipped.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x,
                                                        const T *y,
                                                        const Tout *out,
                                                        const Tout *dout,
                                                        int pre,
                                                        int n,
                                                        int post,
                                                        bool is_xsize_larger,
                                                        DX_OP dx_op,
                                                        DY_OP dy_op,
                                                        T *dx,
                                                        T *dy) {
  const int lane = THREAD_ID_X;
  const int mid = BLOCK_ID_X;
  // Gradient of the smaller (broadcast) input: one reduced value per `mid`.
  T *reduced_grad = is_xsize_larger ? dy : dx;

  T acc(0);
  for (int flat = lane;; flat += ELEMWISE_MAX_BLOCK_DIM) {
    const int outer = flat / post;
    const int inner = flat % post;
    if (outer >= pre) break;

    const int full_offset = outer * n * post + mid * post + inner;
    if (is_xsize_larger) {
      if (dx != nullptr) {
        dx[full_offset] =
            dx_op(x[full_offset], y[mid], out[full_offset], dout[full_offset]);
      }
      if (dy != nullptr) {
        acc += dy_op(x[full_offset], y[mid], out[full_offset], dout[full_offset]);
      }
    } else {  // x.dims < y.dims, broadcast for x.
      if (dy != nullptr) {
        dy[full_offset] =
            dy_op(x[mid], y[full_offset], out[full_offset], dout[full_offset]);
      }
      if (dx != nullptr) {
        acc += dx_op(x[mid], y[full_offset], out[full_offset], dout[full_offset]);
      }
    }
  }

  if (reduced_grad) {
    // Number of threads that took part, i.e. the launch's block size.
    int active = pre * post;
    if (active > ELEMWISE_MAX_BLOCK_DIM) {
      active = ELEMWISE_MAX_BLOCK_DIM;
    }
    acc = phi::backends::gpu::reduceSum(acc, lane, active);
    if (THREAD_ID_X == 0) {
      reduced_grad[mid] = acc;
    }
  }
}

// Host launcher for the "broadcast 1" gradient pattern: the larger input is
// viewed as h x w and the smaller one as a length-w vector.
//
// Small shapes (h or w below 16, half of a 32-lane warp) use the simple
// one-block-per-column kernel. Larger shapes use the 2-D tiled kernel, with
// the grid clamped by LimitGridDim for the current device.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
                                       const T *x,
                                       const T *y,
                                       const Tout *out,
                                       const Tout *dout,
                                       int h,
                                       int w,
                                       bool is_xsize_larger,
                                       DX_OP dx_op,
                                       DY_OP dy_op,
                                       T *dx,
                                       T *dy) {
  constexpr int kHalfWarp = 16;
  const bool is_small = (w < kHalfWarp) || (h < kHalfWarp);

  if (is_small) {
    // 1-D launch: one block per column, threads walking the h rows.
    const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
    const int num_blocks = w;
    ElemwiseGradBroadcast1CUDAKernel<<<num_blocks,
                                       threads_per_block,
                                       0,
                                       stream>>>(
        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
    return;
  }

  // 2-D launch: BLOCK_X x BLOCK_Y tiles, one block per BLOCK_X-wide column
  // strip (performance is assumed to improve as h grows).
  const dim3 tile_dims(BLOCK_X, BLOCK_Y);
  dim3 grid_dims((w + BLOCK_X - 1) / BLOCK_X);
  const auto place = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
  auto *dev_ctx = static_cast<GPUContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(place));
  phi::backends::gpu::LimitGridDim(*dev_ctx, &grid_dims);
  FastElemwiseGradBroadcast1CUDAKernel<<<grid_dims, tile_dims, 0, stream>>>(
      x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
}

// Host launcher for the "broadcast 2" gradient pattern: the larger input is
// viewed as (pre, n, post) and the smaller one as a length-n vector.
//
// Launches one block per middle index (n blocks). Each block has
// min(ELEMWISE_MAX_BLOCK_DIM, pre * post) threads.
// NOTE(review): the grid is clamped by LimitGridDim, but the kernel handles
// only middle index BLOCK_ID_X and has no grid-stride loop. Indices at or
// beyond the clamped grid size would therefore not be processed. This only
// matters if n can exceed the device's x-dimension grid limit.
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
                                       const T *x,
                                       const T *y,
                                       const Tout *out,
                                       const Tout *dout,
                                       int pre,
                                       int n,
                                       int post,
                                       bool is_xsize_larger,
                                       DX_OP dx_op,
                                       DY_OP dy_op,
                                       T *dx,
                                       T *dy) {
  const int plane_size = pre * post;
  const int threads_per_block =
      plane_size < ELEMWISE_MAX_BLOCK_DIM ? plane_size : ELEMWISE_MAX_BLOCK_DIM;

  dim3 grid_dims(n);
  const auto place = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
  auto *dev_ctx = static_cast<GPUContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(place));
  phi::backends::gpu::LimitGridDim(*dev_ctx, &grid_dims);

  ElemwiseGradBroadcast2CUDAKernel<<<grid_dims, threads_per_block, 0, stream>>>(
      x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
}

1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036
// Generic broadcast-gradient reduction. Block BLOCK_ID_X computes dx[i]
// (i = BLOCK_ID_X) as the sum of dx_op over `thread_num` output elements.
// The block's threads stride over those elements by BLOCK_NUM_X, and the
// partial sums are combined with reduceSum.
//
// For each element, the flat index i * thread_num + j is decomposed with
// y_dims_order / y_strides_order into a linear output position. That position
// is then decomposed with out_dims_array into x / y offsets via
// x_strides_array / y_strides_array.
// NOTE(review): the *_order arrays presumably describe a permutation of the
// output dims that places the reduced axes innermost. Confirm against the
// caller.
//
// `out_size` is not read by this kernel; it is kept for interface
// compatibility.
template <typename T, typename DX_OP, typename Tout = T>
__global__ void CommonGradBroadcastCUDAKernel(const int *x_strides_array,
                                              const int *y_strides_array,
                                              const int *out_dims_array,
                                              const int *y_strides_order,
                                              const int *y_dims_order,
                                              const T *x,
                                              const T *y,
                                              const Tout *out,
                                              const Tout *dout,
                                              T *dx,
                                              int out_size,
                                              int max_dim,
                                              int thread_num,
                                              DX_OP dx_op) {
  T val(0);
  int i = BLOCK_ID_X;
  int tid = THREAD_ID_X;
  for (int j = tid; j < thread_num; j += BLOCK_NUM_X) {
    // Flat index -> linear output position in the permuted layout.
    int B_index = i * thread_num + j;
    int C_index = 0;
    int remainder = 0;
#pragma unroll
    for (int d = max_dim - 1; d >= 0; --d) {
      GetDivMod(B_index, y_dims_order[d], &B_index, &remainder);
      C_index += remainder * y_strides_order[d];
    }
    // Linear output position -> x / y offsets. Broadcast axes are expected to
    // carry stride 0 here.
    int x_index = 0;
    int y_index = 0;
    int C_index_val = C_index;
#pragma unroll
    for (int d = max_dim - 1; d >= 0; --d) {
      GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder);
      x_index += remainder * x_strides_array[d];
      y_index += remainder * y_strides_array[d];
    }
    const int out_index = C_index;
    val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]);
  }
  val = phi::backends::gpu::reduceSum(val, tid, thread_num);
  if (THREAD_ID_X == 0) {
    dx[i] = val;
  }
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCUDA(const DenseTensor &x,
                             const DenseTensor &y,
                             const DenseTensor &out,
                             const DenseTensor &dout,
                             DenseTensor *dx,
                             DenseTensor *dy,
                             int *x_dims_array,
                             int *y_dims_array,
                             int *out_dims_array,
                             int max_dim,
                             const GPUContext &ctx,
                             DX_OP dx_op,
                             DY_OP dy_op) {
1082
  const auto gplace = ctx.GetPlace();
1083
  auto cplace = phi::CPUPlace();
1084 1085 1086 1087
  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  const Tout *out_data = out.data<Tout>();
  const Tout *dout_data = dout.data<Tout>();
1088 1089
  T *dx_data = dx == nullptr ? nullptr : ctx.Alloc<T>(dx);
  T *dy_data = dy == nullptr ? nullptr : ctx.Alloc<T>(dy);
1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217

  std::vector<int> x_one_indexs;
  std::vector<int> y_one_indexs;
  for (int i = 0; i < max_dim; i++) {
    if (x_dims_array[i] != y_dims_array[i]) {
      if (x_dims_array[i] == 1) {
        x_one_indexs.push_back(i);
      }
      if (y_dims_array[i] == 1) {
        y_one_indexs.push_back(i);
      }
    }
  }

  std::vector<int> x_trans_indexs(max_dim);
  std::vector<int> y_trans_indexs(max_dim);
  ComputeBroadcastTranspositionArray(
      x_one_indexs.data(), x_trans_indexs.data(), max_dim, x_one_indexs.size());
  ComputeBroadcastTranspositionArray(
      y_one_indexs.data(), y_trans_indexs.data(), max_dim, y_one_indexs.size());

  // compute array stride for cuda kernel;
  // e.g. x.dims=[2,3,4], x_stride=[12,4,1]
  std::vector<int> x_strides_array(max_dim);
  std::vector<int> y_strides_array(max_dim);
  std::vector<int> out_strides_array(max_dim);
  int x_stride = 1;
  int y_stride = 1;
  int z_stride = 1;
  for (int i = max_dim - 1; i >= 0; i--) {
    x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
    y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
    out_strides_array[i] = z_stride;
    x_stride *= x_dims_array[i];
    y_stride *= y_dims_array[i];
    z_stride *= out_dims_array[i];
  }

  std::vector<int> x_strides_order(max_dim);
  std::vector<int> y_strides_order(max_dim);
  std::vector<int> x_dims_order(max_dim);
  std::vector<int> y_dims_order(max_dim);
  for (int i = 0; i < max_dim; ++i) {
    x_strides_order[i] = out_strides_array[x_trans_indexs[i]];
    y_strides_order[i] = out_strides_array[y_trans_indexs[i]];
    x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
    y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
  }
  std::vector<int> x_broadcast_pos;
  std::vector<int> y_broadcast_pos;

  int bytes = max_dim * sizeof(int);

  for (int i = 0; i < max_dim; ++i) {
    if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) {
      x_broadcast_pos.emplace_back(i);
    }
    if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) {
      y_broadcast_pos.emplace_back(i);
    }
  }

  auto stream = ctx.stream();
  bool can_split_x = false;
  bool can_split_y = false;

  auto FastCommonCUDAF = [&](const std::vector<int> &broadcast_pos, bool is_y) {
    int h = std::accumulate(out_dims_array,
                            out_dims_array + broadcast_pos.size(),
                            1,
                            std::multiplies<int>());
    int w = std::accumulate(out_dims_array + broadcast_pos.size(),
                            out_dims_array + max_dim,
                            1,
                            std::multiplies<int>());

    VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h
            << " is_y:" << is_y;

    int split_h;
    int split_w;
    int kh = h;
    int kw = w;

    if (is_y) {
      split_h = std::accumulate(x_dims_array,
                                x_dims_array + broadcast_pos.size(),
                                1,
                                std::multiplies<int>());
      split_w = std::accumulate(x_dims_array + broadcast_pos.size(),
                                x_dims_array + max_dim,
                                1,
                                std::multiplies<int>());

    } else {
      split_h = std::accumulate(y_dims_array,
                                y_dims_array + broadcast_pos.size(),
                                1,
                                std::multiplies<int>());
      split_w = std::accumulate(y_dims_array + broadcast_pos.size(),
                                y_dims_array + max_dim,
                                1,
                                std::multiplies<int>());
    }

    if (h > split_h) kh = split_h;
    if (w > split_w) kw = split_w;

    if (is_y) {
      if (w < 16 || h < 16) {
        int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
        int grid_size = w;
        CommonGradBroadcast1CUDAKernelHeight<<<grid_size,
                                               block_size,
                                               0,
                                               stream>>>(x_data,
                                                         y_data,
                                                         out_data,
                                                         dout_data,
                                                         h,
                                                         w,
                                                         dy_op,
                                                         dy_data,
                                                         kh,
                                                         kw,
                                                         is_y);
      } else {
        dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
1218
        dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
1219
        phi::backends::gpu::LimitGridDim(ctx, &grid_size);
1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254
        FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
                                                  block_size,
                                                  0,
                                                  stream>>>(x_data,
                                                            y_data,
                                                            out_data,
                                                            dout_data,
                                                            h,
                                                            w,
                                                            dy_op,
                                                            dy_data,
                                                            kh,
                                                            kw,
                                                            is_y);
      }
    } else {
      if (w < 16 || h < 16) {
        int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
        int grid_size = w;
        CommonGradBroadcast1CUDAKernelHeight<<<grid_size,
                                               block_size,
                                               0,
                                               stream>>>(x_data,
                                                         y_data,
                                                         out_data,
                                                         dout_data,
                                                         h,
                                                         w,
                                                         dx_op,
                                                         dx_data,
                                                         kh,
                                                         kw,
                                                         is_y);
      } else {
        dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
1255
        dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
1256
        phi::backends::gpu::LimitGridDim(ctx, &grid_size);
1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322
        FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
                                                  block_size,
                                                  0,
                                                  stream>>>(x_data,
                                                            y_data,
                                                            out_data,
                                                            dout_data,
                                                            h,
                                                            w,
                                                            dx_op,
                                                            dx_data,
                                                            kh,
                                                            kw,
                                                            is_y);
      }
    }
  };

  auto FastBroadCastHeightCUDAF = [&](const std::vector<int> &broadcast_pos,
                                      bool x_large) {
    int h = std::accumulate(out_dims_array,
                            out_dims_array + broadcast_pos.size(),
                            1,
                            std::multiplies<int>());
    int w = std::accumulate(out_dims_array + broadcast_pos.size(),
                            out_dims_array + max_dim,
                            1,
                            std::multiplies<int>());

    VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h;

    if (w < 16 || h < 16) {
      int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
      int grid_size = w;
      ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
          x_data,
          y_data,
          out_data,
          dout_data,
          h,
          w,
          x_large,
          dx_op,
          dy_op,
          dx_data,
          dy_data);
    } else {
      dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
      int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
      FastElemwiseGradBroadcast1CUDAKernel<<<grid_size,
                                             block_size,
                                             0,
                                             stream>>>(x_data,
                                                       y_data,
                                                       out_data,
                                                       dout_data,
                                                       h,
                                                       w,
                                                       x_large,
                                                       dx_op,
                                                       dy_op,
                                                       dx_data,
                                                       dy_data);
    }
  };

1323 1324 1325
  auto FastBroadCastAllCUDAF = [&](const std::vector<int> &broadcast_pos,
                                   int max_dim,
                                   bool is_x_large) {
1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352
    int axis = broadcast_pos[0];
    int pre = std::accumulate(
        out_dims_array, out_dims_array + axis, 1, std::multiplies<int>());
    int mid = 1;
    int post = 1;

    if (broadcast_pos.size() == 1) {
      mid = out_dims_array[axis];
      post = std::accumulate(out_dims_array + axis + 1,
                             out_dims_array + max_dim,
                             1,
                             std::multiplies<int>());
    } else {
      mid = std::accumulate(out_dims_array + axis,
                            out_dims_array + broadcast_pos.back() + 1,
                            1,
                            std::multiplies<int>());
      post = std::accumulate(out_dims_array + broadcast_pos.back() + 1,
                             out_dims_array + max_dim,
                             1,
                             std::multiplies<int>());
    }

    VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid
            << " post:" << post;

    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
1353
    dim3 grid_size = dim3(pre * post);
1354
    phi::backends::gpu::LimitGridDim(ctx, &grid_size);
1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370

    FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
        x_data,
        y_data,
        out_data,
        dout_data,
        pre,
        mid,
        post,
        is_x_large,
        dx_op,
        dy_op,
        dx_data,
        dy_data);
  };

1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395
  auto FastBroadCastOneCUDAF =
      [&](const std::vector<int> &broadcast_pos, int max_dim, bool is_x) {
        int axis = broadcast_pos[0];
        int pre = std::accumulate(
            out_dims_array, out_dims_array + axis, 1, std::multiplies<int>());
        int mid = out_dims_array[axis];
        int post = std::accumulate(out_dims_array + axis + 1,
                                   out_dims_array + max_dim,
                                   1,
                                   std::multiplies<int>());

        int k_pre;
        int k_mid;
        int k_post;

        if (is_x) {
          k_pre = std::accumulate(
              y_dims_array, y_dims_array + axis, 1, std::multiplies<int>());
          k_mid = y_dims_array[axis];
          k_post = std::accumulate(y_dims_array + axis + 1,
                                   y_dims_array + max_dim,
                                   1,
                                   std::multiplies<int>());
          int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
          dim3 grid_size = dim3(pre * post);
1396
          phi::backends::gpu::LimitGridDim(ctx, &grid_size);
1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426
          // we need to calc y offset with blockid, so do x_pre/y_pre to get
          // left size.
          if (k_pre != pre) k_pre = pre / k_pre;

          FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
                                                 block_size,
                                                 0,
                                                 stream>>>(x_data,
                                                           y_data,
                                                           out_data,
                                                           dout_data,
                                                           pre,
                                                           mid,
                                                           post,
                                                           k_pre,
                                                           k_mid,
                                                           k_post,
                                                           true,
                                                           dx_op,
                                                           dx_data);
        } else {
          k_pre = std::accumulate(
              x_dims_array, x_dims_array + axis, 1, std::multiplies<int>());
          k_mid = x_dims_array[axis];
          k_post = std::accumulate(x_dims_array + axis + 1,
                                   x_dims_array + max_dim,
                                   1,
                                   std::multiplies<int>());
          int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
          dim3 grid_size = dim3(pre * post);
1427
          phi::backends::gpu::LimitGridDim(ctx, &grid_size);
1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449
          if (k_pre != pre) k_pre = pre / k_pre;

          FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
                                                 block_size,
                                                 0,
                                                 stream>>>(x_data,
                                                           y_data,
                                                           out_data,
                                                           dout_data,
                                                           pre,
                                                           mid,
                                                           post,
                                                           k_pre,
                                                           k_mid,
                                                           k_post,
                                                           false,
                                                           dy_op,
                                                           dy_data);
        }
        VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid
                << " post:" << post;
      };
1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532

  // do fast elementwise if: 1. only one input need to do broadcast, we can
  // fallback
  // to old fast path.
  // 2. if both x and y need broadcast, then do it one by one.
  bool fast_broadcast = false;
  if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
    can_split_y = SplitDims(y_broadcast_pos, max_dim);
    if (can_split_y) {
      // only y need to do broadcast on h
      if (y_broadcast_pos[0] == 0) {
        FastBroadCastHeightCUDAF(y_broadcast_pos, true);
        fast_broadcast = true;
      }
    } else if (y_broadcast_pos.size() == 1 ||
               CheckContiguousDims(y_broadcast_pos)) {  // for only one dim and
                                                        // contiguous broadcast.
      // If cannot split,  which means input has 3 parts
      FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
      fast_broadcast = true;
    }
  } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
    // only x need broadcast
    can_split_x = SplitDims(x_broadcast_pos, max_dim);
    if (can_split_x) {
      if (x_broadcast_pos[0] == 0) {
        FastBroadCastHeightCUDAF(x_broadcast_pos, false);
        fast_broadcast = true;
      }
    } else if (x_broadcast_pos.size() == 1 ||
               CheckContiguousDims(x_broadcast_pos)) {
      FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
      fast_broadcast = true;
    }
  } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
    // do x and y broadcast each.
    can_split_y = SplitDims(y_broadcast_pos, max_dim);
    bool fast_broadcast_x = false;
    bool fast_broadcast_y = false;
    if (can_split_y) {
      // begin at start.
      if (y_broadcast_pos[0] == 0) {
        FastCommonCUDAF(y_broadcast_pos, true);
        fast_broadcast_y = true;
      }
    } else if (y_broadcast_pos.size() == 1) {
      FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
      can_split_y = true;
      fast_broadcast_y = true;
    }
    can_split_x = SplitDims(x_broadcast_pos, max_dim);
    if (can_split_x) {
      if (x_broadcast_pos[0] == 0) {
        FastCommonCUDAF(x_broadcast_pos, false);
        fast_broadcast_x = true;
      }
    } else if (x_broadcast_pos.size() == 1) {
      FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
      can_split_x = true;
      fast_broadcast_x = true;
    }
    VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
            << " can_split_x:" << can_split_x;
    // if both x and y into fast path then return
    if (fast_broadcast_x && fast_broadcast_y) {
      fast_broadcast = true;
    }
    if (can_split_y && can_split_x && fast_broadcast) return;
  }

  // Should remove memory copy, use reg instead.
  if (fast_broadcast) {
    return;
  }
  int x_blocks = 0;
  int x_threads = 0;
  ComputeBroadcastKernelSize(
      x_dims_array, out_dims_array, &x_blocks, &x_threads, max_dim);
  int y_blocks = 0;
  int y_threads = 0;
  ComputeBroadcastKernelSize(
      y_dims_array, out_dims_array, &y_blocks, &y_threads, max_dim);

1533 1534 1535 1536
  auto x_strides_array_tmp = paddle::memory::Alloc(
      ctx.GetPlace(),
      bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1537 1538 1539 1540 1541 1542 1543 1544 1545
  int *x_strides_array_gpu =
      reinterpret_cast<int *>(x_strides_array_tmp->ptr());
  paddle::memory::Copy(gplace,
                       x_strides_array_gpu,
                       cplace,
                       x_strides_array.data(),
                       bytes,
                       ctx.stream());

1546 1547 1548 1549
  auto y_strides_array_tmp = paddle::memory::Alloc(
      ctx.GetPlace(),
      bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1550 1551 1552 1553 1554 1555 1556 1557 1558
  int *y_strides_array_gpu =
      reinterpret_cast<int *>(y_strides_array_tmp->ptr());
  paddle::memory::Copy(gplace,
                       y_strides_array_gpu,
                       cplace,
                       y_strides_array.data(),
                       bytes,
                       ctx.stream());

1559 1560 1561 1562
  auto out_dims_array_tmp = paddle::memory::Alloc(
      ctx.GetPlace(),
      bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1563 1564 1565 1566 1567 1568 1569 1570 1571
  int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
  paddle::memory::Copy(
      gplace, out_dims_array_gpu, cplace, out_dims_array, bytes, ctx.stream());

  const int out_size = std::accumulate(
      out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
  int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
  int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
  if (dx) {
1572 1573 1574 1575
    auto x_strides_order_tmp = paddle::memory::Alloc(
        ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1576 1577 1578 1579 1580 1581 1582 1583 1584
    int *x_strides_order_gpu =
        reinterpret_cast<int *>(x_strides_order_tmp->ptr());
    paddle::memory::Copy(gplace,
                         x_strides_order_gpu,
                         cplace,
                         x_strides_order.data(),
                         bytes,
                         ctx.stream());

1585 1586 1587 1588
    auto x_dims_order_tmp = paddle::memory::Alloc(
        ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1589 1590 1591 1592 1593 1594 1595
    int *x_dims_order_gpu = reinterpret_cast<int *>(x_dims_order_tmp->ptr());
    paddle::memory::Copy(gplace,
                         x_dims_order_gpu,
                         cplace,
                         x_dims_order.data(),
                         bytes,
                         ctx.stream());
1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610
    CommonGradBroadcastCUDAKernel<T, DX_OP, Tout>
        <<<x_blocks, x_block_size, 0, ctx.stream()>>>(x_strides_array_gpu,
                                                      y_strides_array_gpu,
                                                      out_dims_array_gpu,
                                                      x_strides_order_gpu,
                                                      x_dims_order_gpu,
                                                      x_data,
                                                      y_data,
                                                      out_data,
                                                      dout_data,
                                                      dx_data,
                                                      out_size,
                                                      max_dim,
                                                      x_threads,
                                                      dx_op);
1611 1612
  }
  if (dy) {
1613 1614 1615 1616
    auto y_strides_order_tmp = paddle::memory::Alloc(
        ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1617 1618 1619 1620 1621 1622 1623 1624 1625
    int *y_strides_order_gpu =
        reinterpret_cast<int *>(y_strides_order_tmp->ptr());
    paddle::memory::Copy(gplace,
                         y_strides_order_gpu,
                         cplace,
                         y_strides_order.data(),
                         bytes,
                         ctx.stream());

1626 1627 1628 1629
    auto y_dims_order_tmp = paddle::memory::Alloc(
        ctx.GetPlace(),
        bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
1630 1631 1632 1633 1634 1635 1636
    int *y_dims_order_gpu = reinterpret_cast<int *>(y_dims_order_tmp->ptr());
    paddle::memory::Copy(gplace,
                         y_dims_order_gpu,
                         cplace,
                         y_dims_order.data(),
                         bytes,
                         ctx.stream());
1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651
    CommonGradBroadcastCUDAKernel<T, DY_OP, Tout>
        <<<y_blocks, y_block_size, 0, ctx.stream()>>>(x_strides_array_gpu,
                                                      y_strides_array_gpu,
                                                      out_dims_array_gpu,
                                                      y_strides_order_gpu,
                                                      y_dims_order_gpu,
                                                      x_data,
                                                      y_data,
                                                      out_data,
                                                      dout_data,
                                                      dy_data,
                                                      out_size,
                                                      max_dim,
                                                      y_threads,
                                                      dy_op);
1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672
  }
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonElementwiseBroadcastBackward(const GPUContext &ctx,
                                        const DDim &x_dims,
                                        const DDim &y_dims,
                                        const DenseTensor &x,
                                        const DenseTensor &y,
                                        const DenseTensor &out,
                                        const DenseTensor &dout,
                                        int axis,
                                        DenseTensor *dx,
                                        DenseTensor *dy,
                                        DX_OP dx_op,
                                        DY_OP dy_op) {
  int max_dim = std::max(x_dims.size(), y_dims.size());
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
1673 1674 1675 1676 1677 1678 1679
  GetBroadcastDimsArrays(x_dims,
                         y_dims,
                         x_dims_array.data(),
                         y_dims_array.data(),
                         out_dims_array.data(),
                         max_dim,
                         axis);
1680 1681 1682 1683
  // for inplace strategy. memset will make dx and dout clear and get wrong
  // result.
  if (dx && dx->IsSharedBufferWith(dout)) {
    dx->clear();
1684 1685
    dx->Resize(x_dims);
    ctx.template Alloc<T>(dx);
1686 1687 1688
  }

  VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
1689 1690
          << phi::make_ddim(x_dims_array)
          << " ydim:" << phi::make_ddim(y_dims_array);
1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731

  CommonGradBroadcastCUDA<T, DX_OP, DY_OP, Tout>(x,
                                                 y,
                                                 out,
                                                 dout,
                                                 dx,
                                                 dy,
                                                 x_dims_array.data(),
                                                 y_dims_array.data(),
                                                 out_dims_array.data(),
                                                 max_dim,
                                                 ctx,
                                                 dx_op,
                                                 dy_op);
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
                                      const DDim &x_dims,
                                      const DDim &y_dims,
                                      const DenseTensor &x,
                                      const DenseTensor &y,
                                      const DenseTensor &out,
                                      const DenseTensor &dout,
                                      int axis,
                                      DenseTensor *dx,
                                      DenseTensor *dy,
                                      DX_OP dx_op,
                                      DY_OP dy_op) {
  bool is_xsize_larger = true;

  int max_dim = x_dims.size();
  if (x_dims.size() < y_dims.size()) {
    is_xsize_larger = false;
    max_dim = y_dims.size();
  }

  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  PADDLE_ENFORCE_GE(
      axis,
      0,
1732
      errors::InvalidArgument(
1733 1734
          "Axis should be great than or equal to 0, but received axis is %d.",
          axis));
1735 1736 1737 1738 1739 1740 1741
  PADDLE_ENFORCE_LE(
      axis,
      max_dim,
      errors::InvalidArgument(
          "Axis should be less than or equal to %d, but received axis is %d.",
          max_dim,
          axis));
1742 1743 1744

  int pre, n, post, is_run_common_broadcast, axis_trim = 0;
  if (is_xsize_larger) {
1745
    auto y_dims_trimed = TrimTrailingSingularDims(y_dims);
1746
    axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
1747 1748 1749 1750 1751 1752 1753
    GetMidDims(x_dims,
               y_dims_trimed,
               axis_trim,
               &pre,
               &n,
               &post,
               &is_run_common_broadcast);
1754
  } else {
1755
    auto x_dims_trimed = TrimTrailingSingularDims(x_dims);
1756
    axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
1757 1758 1759 1760 1761 1762 1763
    GetMidDims(y_dims,
               x_dims_trimed,
               axis_trim,
               &pre,
               &n,
               &post,
               &is_run_common_broadcast);
1764 1765 1766 1767 1768 1769 1770 1771
  }
  // special case for common backward implementation.
  if (is_run_common_broadcast) {
    CommonElementwiseBroadcastBackward<T, DX_OP, DY_OP, Tout>(
        ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }
  if (post == 1) {
1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783
    ElemwiseGradBroadcast1CUDA(ctx.stream(),
                               x.data<T>(),
                               y.data<T>(),
                               out.data<Tout>(),
                               dout.data<Tout>(),
                               pre,
                               n,
                               is_xsize_larger,
                               dx_op,
                               dy_op,
                               dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
                               dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
1784
  } else {
1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797
    ElemwiseGradBroadcast2CUDA(ctx.stream(),
                               x.data<T>(),
                               y.data<T>(),
                               out.data<Tout>(),
                               dout.data<Tout>(),
                               pre,
                               n,
                               post,
                               is_xsize_larger,
                               dx_op,
                               dy_op,
                               dx == nullptr ? nullptr : ctx.Alloc<T>(dx),
                               dy == nullptr ? nullptr : ctx.Alloc<T>(dy));
1798 1799 1800
  }
}

1801 1802
#endif

1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828
/**
 * @brief Entry point for the backward pass of a binary elementwise op.
 *
 * When x and y have identical dims, the work goes to
 * ElemwiseGradComputeNoBroadcast. Otherwise it goes to
 * ElemwiseGradComputeWithBroadcast. NOTE(review): the broadcast overload
 * appears to be chosen by the concrete context type; confirm that a CPU
 * overload exists outside this view.
 *
 * @param dx, dy  gradient outputs; either may be nullptr.
 */
template <typename DeviceContext,
          typename T,
          typename DX_OP,
          typename DY_OP,
          typename Tout = T>
void ElemwiseGradCompute(const DeviceContext &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor &y,
                         const DenseTensor &out,
                         const DenseTensor &dout,
                         int axis,
                         DenseTensor *dx,
                         DenseTensor *dy,
                         DX_OP dx_op,
                         DY_OP dy_op) {
  const DDim &x_dim = x.dims();
  const DDim &y_dim = y.dims();

  // Fast path: identical shapes need no broadcast handling.
  if (x_dim == y_dim) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP, Tout>(
        dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }

  ElemwiseGradComputeWithBroadcast<T, DX_OP, DY_OP, Tout>(
      dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}

1829
}  // namespace funcs
1830
}  // namespace phi