elementwise_op_function.h 96.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#pragma once
16

17
#include <glog/logging.h>
18

19
#include <algorithm>
#include <cstring>     // for memset
#include <functional>  // for multiplies
#include <iterator>
#include <numeric>     // for accumulate
#include <vector>
23

Y
Yi Wang 已提交
24 25 26
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
27
#include "paddle/fluid/memory/malloc.h"
28
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
29
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
Y
Yi Wang 已提交
30
#include "paddle/fluid/platform/transform.h"
31

32 33
// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
C
Chen Weihang 已提交
34 35
#include "paddle/pten/kernels/hybird/cpu/elementwise.h"
#include "paddle/pten/kernels/hybird/general/elementwise_base.h"
36

37
#if defined(__NVCC__) || defined(__HIPCC__)
C
chengduoZH 已提交
38
#ifdef __NVCC__
39
#include <cuda.h>
40 41 42
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#endif
C
chengduoZH 已提交
43
#include <thrust/iterator/iterator_adaptor.h>
44

45
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
46 47
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
48

R
ronnywang 已提交
49 50 51
// Upper bound on threads per block used by the 1-D gradient kernels in this
// header (their row loops stride by this value). HIP builds use a smaller
// limit than CUDA builds.
#ifndef __HIPCC__
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#else
constexpr int ELEMWISE_MAX_BLOCK_DIM = 256;
#endif
// Tile extents of the shared-memory buffers used by the 2-D "Fast*" kernels
// (presumably also their blockDim.x / blockDim.y).
#define BLOCK_X 32
#define BLOCK_Y 32
C
chengduoZH 已提交
56 57
#endif

Y
Yi Wang 已提交
58
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
59
#include "paddle/fluid/platform/for_range.h"
60 61 62 63 64 65
// Stores `dividend / divisor` through `div` and `dividend % divisor` through
// `mod`. The dividend is evaluated once into a local copy before either store,
// so the same variable may be passed as the dividend and as the quotient
// target (e.g. GetDivMod(v, d, &v, &r)). Every macro argument is
// parenthesized so expression arguments such as `a + b` group correctly.
#define GetDivMod(dividend, divisor, div, mod) \
  do {                                         \
    const auto dividend_copy = (dividend);     \
    *(div) = dividend_copy / (divisor);        \
    *(mod) = dividend_copy % (divisor);        \
  } while (0)
66

67 68 69 70
// DIVUP(x, y): x / y rounded up (ceil) for non-negative x and positive y.
#define DIVUP(x, y) (((x) + (y)-1) / (y))
// ROUNDUP(x, y): smallest multiple of y that is >= x (same domain as DIVUP).
#define ROUNDUP(x, y) ((((x) + (y)-1) / (y)) * (y))

71 72 73
namespace paddle {
namespace operators {

74
/*
75 76 77 78 79 80 81
*  Pack input and output tensors into respective vectors with
*  consideration of varible X`s class type.
*  Input variable X is supported to be whether LoDTensor or
*  SelectedRows class type in this package function, once X
*  was SelectedRows type, a valid pointer x_for_selectedrows
*  is excepted to be passed in from op kernel for acquisition
*  of the valid address of LoDTensor created ahead in the function.
82
*/
83 84 85
template <typename OutT>
int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
                          std::vector<const framework::Tensor *> *ins,
86 87
                          std::vector<framework::Tensor *> *outs,
                          framework::Tensor *x_for_selectedrows = nullptr) {
88
  int axis = -1;
89 90 91 92 93
  auto x_var = ctx.InputVar("X");
  PADDLE_ENFORCE_NOT_NULL(
      x_var, platform::errors::InvalidArgument(
                 "Unable to get input Variable X, Variable name is %s.\n",
                 ctx.InputName("X")));
94
  auto *y = ctx.Input<framework::LoDTensor>("Y");
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
  framework::Tensor *z;

  if (x_var->IsType<framework::LoDTensor>()) {
    auto *x = ctx.Input<framework::LoDTensor>("X");
    z = ctx.Output<framework::LoDTensor>("Out");
    ins->emplace_back(x);
  } else if (x_var->IsType<framework::SelectedRows>()) {
    PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
                      platform::errors::InvalidArgument(
                          "For elementwise_op, if X is Sparse, Y must be "
                          "scalar. But reveived the size of Y = %d.",
                          y->dims().size()));
    PADDLE_ENFORCE_NOT_NULL(
        x_for_selectedrows,
        platform::errors::InvalidArgument(
            "The parameter x_for_selectedrows is excepted to "
            "be valid, once input varible X`s class type is "
            "SelectedRows.\n"));
    auto &x_sele = x_var->Get<framework::SelectedRows>();
    auto out_sele = ctx.Output<framework::SelectedRows>("Out");
    *x_for_selectedrows = x_sele.value();
    out_sele->set_rows(x_sele.rows());
    out_sele->set_height(x_sele.height());
    out_sele->mutable_value()->Resize(x_sele.value().dims());
    out_sele->mutable_value()->mutable_data(ctx.GetPlace(),
                                            x_for_selectedrows->type());
    z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
    ins->emplace_back(x_for_selectedrows);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "X's type[%s] is not supported by elementwise_op. X's type should be "
        "LoDTensor or SelectedRows.",
        framework::ToTypeName(x_var->Type())));
  }
129
  z->mutable_data<OutT>(ctx.GetPlace());
130 131 132 133
  outs->emplace_back(z);

  if (y != nullptr) {
    ins->emplace_back(y);
134
    axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
135
  }
136
  return axis;
137 138
}

139 140
inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
                               const int *index_array) {
141
  return pten::GetElementwiseIndex(x_dims_array, max_dim, index_array);
142 143 144 145
}

inline void UpdateElementwiseIndexArray(const int *out_dims_array,
                                        const int max_dim, int *index_array) {
146
  pten::UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array);
147 148 149 150 151 152 153
}

// Fills x_dims_array, y_dims_array and out_dims_array (presumably max_dim
// entries each) with the broadcast-aligned shapes of x, y and the output,
// given the op's `axis`. The alignment rules live in
// pten::general::GetBroadcastDimsArrays, to which this simply forwards.
inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
                                   const framework::DDim &y_dims,
                                   int *x_dims_array, int *y_dims_array,
                                   int *out_dims_array, const int max_dim,
                                   const int axis) {
  namespace general = pten::general;
  general::GetBroadcastDimsArrays(x_dims, y_dims,  // input shapes
                                  x_dims_array, y_dims_array,
                                  out_dims_array,  // outputs
                                  max_dim, axis);
}
158

159 160 161 162 163 164 165 166
// CPU forward pass of a broadcast elementwise op. Presumably writes
// func(x, y) into z, using the max_dim-long dimension arrays of x, y and the
// output. `is_xsize_larger` is passed through unchanged. Thin forwarder to
// pten::CommonForwardBroadcastCPU.
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCPU(const framework::Tensor *x,
                               const framework::Tensor *y, framework::Tensor *z,
                               int *x_dims_array, int *y_dims_array,
                               int *out_dims_array, int max_dim,
                               const platform::CPUDeviceContext &ctx,
                               Functor func,
                               const bool is_xsize_larger = true) {
  pten::CommonForwardBroadcastCPU(
      x, y, z,                                     // operands and result
      x_dims_array, y_dims_array, out_dims_array,  // per-axis shapes
      max_dim, ctx, func, is_xsize_larger);
}

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
// CPU gradient of a broadcast elementwise op.
//
// Visits every position of the broadcast output (shape out_dims_array, with
// max_dim axes) and maps it to the matching elements of x and y with
// GetElementwiseIndex. It then adds dx_op / dy_op (called as
// op(x, y, out, dout)) into dx / dy. Several output positions can map to the
// same x or y element, so dx and dy are zero-filled first and then summed
// into. Either dx or dy may be nullptr, in which case that gradient is
// skipped.
template <typename T, typename DX_OP, typename DY_OP>
void CommonGradBroadcastCPU(
    const framework::Tensor &x, const framework::Tensor &y,
    const framework::Tensor &out, const framework::Tensor &dout,
    framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
    int *y_dims_array, int *out_dims_array, int max_dim,
    const platform::CPUDeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
  // Current multi-dimensional output position, starting at all zeros.
  std::vector<int> coord(max_dim, 0);

  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  const T *out_data = out.data<T>();
  const T *dout_data = dout.data<T>();

  T *dx_data = nullptr;
  if (dx != nullptr) {
    dx_data = dx->mutable_data<T>(ctx.GetPlace());
  }
  T *dy_data = nullptr;
  if (dy != nullptr) {
    dy_data = dy->mutable_data<T>(ctx.GetPlace());
  }
  if (dx_data != nullptr) memset(dx_data, 0, dx->numel() * sizeof(T));
  if (dy_data != nullptr) memset(dy_data, 0, dy->numel() * sizeof(T));

  // Total number of output elements.
  int total = 1;
  for (int d = 0; d < max_dim; ++d) {
    total *= out_dims_array[d];
  }

  for (int pos = 0; pos < total; ++pos) {
    const int xi = GetElementwiseIndex(x_dims_array, max_dim, coord.data());
    const int yi = GetElementwiseIndex(y_dims_array, max_dim, coord.data());
    if (dx_data != nullptr) {
      dx_data[xi] +=
          dx_op(x_data[xi], y_data[yi], out_data[pos], dout_data[pos]);
    }
    if (dy_data != nullptr) {
      dy_data[yi] +=
          dy_op(x_data[xi], y_data[yi], out_data[pos], dout_data[pos]);
    }
    UpdateElementwiseIndexArray(out_dims_array, max_dim, coord.data());
  }
}

// Splits the output shape into two factors relative to x:
//   *x_blocks  = product of the extents of axes where x matches the output,
//   *x_threads = product of the output extents of axes where x differs
//                (presumably the axes x is broadcast along).
// Both start at 1, so max_dim == 0 yields (1, 1).
inline void ComputeBroadcastKernelSize(int *x_dims_array, int *out_dims_array,
                                       int *x_blocks, int *x_threads,
                                       int max_dim) {
  *x_blocks = 1;
  *x_threads = 1;
  for (int axis = 0; axis < max_dim; ++axis) {
    // On matching axes x_dims_array[axis] == out_dims_array[axis], so the
    // output extent is the correct factor for either bucket.
    int *bucket =
        (x_dims_array[axis] == out_dims_array[axis]) ? x_blocks : x_threads;
    *bucket *= out_dims_array[axis];
  }
}

// Builds, in x_trans_indexs, a permutation of the max_dim axes. It lists the
// axes NOT in x_one_indexs first (ascending), followed by the x_one_size axes
// of x_one_indexs in the order given. The sweep below matches tail entries
// monotonically, so x_one_indexs is expected to be sorted ascending with no
// duplicates.
inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
                                               int *x_trans_indexs,
                                               const int max_dim,
                                               const int x_one_size) {
  const int tail_start = max_dim - x_one_size;
  // Place the listed axes at the tail of the permutation.
  for (int k = 0; k < x_one_size; ++k) {
    x_trans_indexs[tail_start + k] = x_one_indexs[k];
  }
  // Sweep all axes; an axis not matched by the next tail entry goes to the
  // head, in ascending order.
  int head = 0;
  int tail = tail_start;
  for (int axis = 0; axis < max_dim; ++axis) {
    const bool is_listed = tail < max_dim && x_trans_indexs[tail] == axis;
    if (is_listed) {
      ++tail;
    } else {
      x_trans_indexs[head] = axis;
      ++head;
    }
  }
}

242
#if defined(__NVCC__) || defined(__HIPCC__)
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
// Gradient kernel for the "broadcast along rows" case. As the indexing shows,
// the larger operand is addressed as an h x w row-major matrix
// (row * w + col) and the smaller operand by column only.
//
// One block handles one column (blockIdx.x). Its threads walk the rows
// starting at threadIdx.x in steps of ELEMWISE_MAX_BLOCK_DIM. The larger
// operand's gradient is written element-wise. The smaller operand's gradient
// is accumulated per thread, combined with platform::reduceSum, and stored by
// thread 0.
// NOTE(review): the row loop is a do/while, so each thread touches row
// threadIdx.x at least once; this presumably relies on the launcher using
// blockDim.x <= h. Confirm against the caller.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const int col = blockIdx.x;
  const int lane = threadIdx.x;
  // Number of threads whose partial sums take part in the block reduction.
  const int reduce_len =
      h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
  T partial(0);

  if (is_xsize_larger) {
    // x is the h x w operand; y is broadcast over the rows.
    int row = lane;
    do {
      const int offset = row * w + col;
      if (dx) {
        dx[offset] = dx_op(x[offset], y[col], out[offset], dout[offset]);
      }
      if (dy) {
        partial += dy_op(x[offset], y[col], out[offset], dout[offset]);
      }
      row += ELEMWISE_MAX_BLOCK_DIM;
    } while (row < h);

    if (dy) {
      partial = paddle::platform::reduceSum(partial, lane, reduce_len);
      if (lane == 0) {
        dy[col] = partial;
      }
    }
    return;
  }

  // x.dims < y.dims: y is the h x w operand; x is broadcast over the rows.
  int row = lane;
  do {
    const int offset = row * w + col;
    if (dy) {
      dy[offset] = dy_op(x[col], y[offset], out[offset], dout[offset]);
    }
    if (dx) {
      partial += dx_op(x[col], y[offset], out[offset], dout[offset]);
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dx) {
    partial = paddle::platform::reduceSum(partial, lane, reduce_len);
    if (lane == 0) {
      dx[col] = partial;
    }
  }
}

// 2-D-block variant of ElemwiseGradBroadcast1CUDAKernel. It is presumably
// faster because it exposes more parallelism and keeps global-memory accesses
// coalesced (consecutive threadIdx.x touch consecutive columns m).
//
// Layout, as used below: the larger operand is an h x w row-major matrix
// (offset n * w + m) and the smaller operand is indexed by column only
// (y[m] / x[m]). Each thread owns column m = blockIdx.x * blockDim.x +
// threadIdx.x, advanced by a grid stride of gridDim.x * blockDim.x, and walks
// rows n = threadIdx.y, threadIdx.y + BLOCK_Y, ...
// This assumes blockDim.x == BLOCK_X and blockDim.y == BLOCK_Y; TODO confirm
// against the launcher.
//
// The column and row loops run to w and h rounded up to multiples of
// BLOCK_X / BLOCK_Y. Under the assumption above, every thread of a block then
// runs the same number of iterations and reaches each __syncthreads().
// Out-of-range threads (m >= w or n >= h) only skip the arithmetic.
//
// The broadcast operand's gradient is reduced per column through the
// shared-memory tile `sdata`:
//  1. each thread accumulates its rows into sdata[threadIdx.y][threadIdx.x];
//  2. it reads the transposed element sdata[threadIdx.x][threadIdx.y] and
//     sums it across the warp with XOR shuffles;
//  3. lane 0 publishes the column total to sdata[0][threadIdx.y];
//  4. the threadIdx.y == 0 row writes it out.
// NOTE(review): the warp reduction assumes BLOCK_X equals the warp size (32)
// with a full 0xFFFFFFFF mask; on HIP targets with 64-wide warps this should
// be double-checked.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  // The extra column (BLOCK_X + 1) presumably avoids shared-memory bank
  // conflicts on the transposed read below.
  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];

  // Not read: the loops below declare their own `val`, which shadows this.
  T val(0);
  size_t width_stride = gridDim.x * blockDim.x;
  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  // w and h rounded up to multiples of BLOCK_X / BLOCK_Y (both powers of 2).
  size_t full_width =
      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
  size_t full_height =
      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
  if (is_xsize_larger) {
    // x is the h x w operand; y is broadcast along the rows.
    for (int m = idx; m < full_width; m += width_stride) {
      sdata[threadIdx.y][threadIdx.x] = 0;
      for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
        int x_offset = n * w + m;
        if (dx && m < w && n < h) {
          dx[x_offset] =
              dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
        }
        // dy is a kernel argument, i.e. uniform across the block, so every
        // thread reaches this barrier.
        if (dy) {
          if (m < w && n < h) {
            T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
            sdata[threadIdx.y][threadIdx.x] += val;
          }
          __syncthreads();
        }
      }
      if (dy) {
        // Transposed read: this warp now holds all row-partials of column
        // threadIdx.y of the tile, which are summed with XOR shuffles.
        T my_val = sdata[threadIdx.x][threadIdx.y];
        for (int i = warpSize >> 1; i > 0; i >>= 1)
          my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
        __syncthreads();
        if ((threadIdx.x == 0)) {
          sdata[0][threadIdx.y] = my_val;
        }
        __syncthreads();
        if (threadIdx.y == 0 && m < w) {
          dy[m] = sdata[0][threadIdx.x];
        }
      }
    }
  } else {  // x.dims < y.dims: y is the h x w operand, x is broadcast.
    for (int m = idx; m < full_width; m += width_stride) {
      sdata[threadIdx.y][threadIdx.x] = 0;
      for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
        int y_offset = n * w + m;
        if (dy && m < w && n < h) {
          dy[y_offset] =
              dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
        }
        // dx is uniform across the block, so every thread reaches this
        // barrier.
        if (dx) {
          if (m < w && n < h) {
            T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
            sdata[threadIdx.y][threadIdx.x] += val;
          }
          __syncthreads();
        }
      }
      if (dx) {
        // Same transposed warp reduction as in the branch above.
        T my_val = sdata[threadIdx.x][threadIdx.y];
        for (int i = warpSize >> 1; i > 0; i >>= 1)
          my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
        __syncthreads();
        if ((threadIdx.x == 0)) {
          sdata[0][threadIdx.y] = my_val;
        }
        __syncthreads();
        if (threadIdx.y == 0 && m < w) {
          dx[m] = sdata[0][threadIdx.x];
        }
      }
    }
  }
}

372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410
// General-rank broadcast gradient kernel. Block blockIdx.x produces
// dx[blockIdx.x] by summing dx_op over `thread_num` work items, which the
// block's threads share in steps of blockDim.x. Each work item proceeds as
// follows:
//  1. its linear id (blockIdx.x * thread_num + k) is decoded digit-by-digit
//     with y_dims_order / y_strides_order into an offset into out / dout;
//  2. that offset is decoded with out_dims_array into x and y offsets via
//     x_strides_array / y_strides_array.
// The per-thread partial sums are combined with platform::reduceSum.
// NOTE(review): the *_order arrays presumably describe a transposed iteration
// order built by the host-side caller; confirm there. `out_size` is not read.
template <typename T, typename DX_OP>
__global__ void CommonGradBroadcastCUDAKernel(
    const int *x_strides_array, const int *y_strides_array,
    const int *out_dims_array, const int *y_strides_order,
    const int *y_dims_order, const T *x, const T *y, const T *out,
    const T *dout, T *dx, int out_size, int max_dim, int thread_num,
    DX_OP dx_op) {
  const int block = blockIdx.x;
  const int lane = threadIdx.x;
  T acc(0);
  for (int k = lane; k < thread_num; k += blockDim.x) {
    // Step 1: ordered linear id -> offset into out / dout.
    int ordered = block * thread_num + k;
    int out_offset = 0;
    int digit = 0;
#pragma unroll
    for (int d = max_dim - 1; d >= 0; --d) {
      GetDivMod(ordered, y_dims_order[d], &ordered, &digit);
      out_offset += digit * y_strides_order[d];
    }
    // Step 2: out offset -> offsets into x and y.
    int x_offset = 0;
    int y_offset = 0;
    int rest = out_offset;
#pragma unroll
    for (int d = max_dim - 1; d >= 0; --d) {
      GetDivMod(rest, out_dims_array[d], &rest, &digit);
      x_offset += digit * x_strides_array[d];
      y_offset += digit * y_strides_array[d];
    }
    acc += dx_op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]);
  }
  acc = paddle::platform::reduceSum(acc, lane, thread_num);
  if (lane == 0) {
    dx[block] = acc;
  }
}

411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567
// Reduces the gradient of a row-broadcast operand over the h rows of an
// h x w output.
//
// One block handles one column (blockIdx.x); threads walk the rows starting
// at threadIdx.x in steps of ELEMWISE_MAX_BLOCK_DIM. The operand being
// reduced is indexed by column only. The other operand is addressed modulo
// its own (x_h, x_w) extent. `is_y` selects which of x / y is the reduced
// one:
//  - is_y == true:  dy_op(x[other], y[col], ...);
//  - is_y == false: dy_op(x[col], y[other], ...).
// Partial sums are combined with platform::reduceSum and thread 0 writes
// dy[col]. Nothing is computed when dy is null.
template <typename T, typename DY_OP>
static __global__ void CommonGradBroadcast1CUDAKernelHeight(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
  const int col = blockIdx.x;
  const int lane = threadIdx.x;
  T partial(0);

  int row = lane;
  do {
    const int out_offset = row * w + col;
    const int other_offset = (row % x_h) * x_w + col % x_w;
    if (dy) {
      if (is_y) {
        partial +=
            dy_op(x[other_offset], y[col], out[out_offset], dout[out_offset]);
      } else {
        partial +=
            dy_op(x[col], y[other_offset], out[out_offset], dout[out_offset]);
      }
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dy) {
    const int reduce_len =
        h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    partial = paddle::platform::reduceSum(partial, lane, reduce_len);
    if (lane == 0) {
      dy[col] = partial;
    }
  }
}

// 2-D-block variant of CommonGradBroadcast1CUDAKernelHeight. It reduces dy
// over the h rows of an h x w output for each column m.
//
// The reduced operand is indexed by column only. The other operand is
// addressed modulo its own (x_h, x_w) extent. `is_y` selects which of x / y
// is the reduced one.
//
// Each thread owns column m = blockIdx.x * blockDim.x + threadIdx.x, advanced
// by a grid stride of gridDim.x * blockDim.x, and walks rows
// n = threadIdx.y, threadIdx.y + BLOCK_Y, ...
// This assumes blockDim.x == BLOCK_X and blockDim.y == BLOCK_Y; TODO confirm
// against the launcher.
//
// Loops run to w and h rounded up to multiples of BLOCK_X / BLOCK_Y, so that
// (under the assumption above) all threads of a block reach every
// __syncthreads(); out-of-range threads skip only the arithmetic.
//
// Per-column reduction via the shared tile `sdata`:
//  1. accumulate into sdata[threadIdx.y][threadIdx.x];
//  2. read it transposed and sum across the warp with XOR shuffles;
//  3. lane 0 stores the total in sdata[0][threadIdx.y];
//  4. the threadIdx.y == 0 row writes dy[m].
// NOTE(review): assumes 32-lane warps (BLOCK_X == warpSize, mask
// 0xFFFFFFFF); double-check on HIP targets with 64-wide warps.
template <typename T, typename DY_OP>
static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
  // The extra column presumably avoids bank conflicts on the transposed read.
  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];

  // Not read: the loops below declare their own `val`, which shadows this.
  T val(0);
  size_t width_stride = gridDim.x * blockDim.x;
  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  // w and h rounded up to multiples of BLOCK_X / BLOCK_Y (both powers of 2).
  size_t full_width =
      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
  size_t full_height =
      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
  if (is_y) {
    // y is reduced (indexed by column); x is read modulo (x_h, x_w).
    for (int m = idx; m < full_width; m += width_stride) {
      sdata[threadIdx.y][threadIdx.x] = 0;
      for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
        int out_offset = n * w + m;
        int x_offset = (n % x_h) * x_w + m % x_w;
        // dy is uniform across the block, so every thread reaches the
        // barrier below.
        if (dy) {
          if (m < w && n < h) {
            T val = dy_op(x[x_offset], y[m], out[out_offset], dout[out_offset]);
            sdata[threadIdx.y][threadIdx.x] += val;
          }
          __syncthreads();
        }
      }
      if (dy) {
        // Transposed read, then warp-wide XOR-shuffle sum of one tile column.
        T my_val = sdata[threadIdx.x][threadIdx.y];
        for (int i = warpSize >> 1; i > 0; i >>= 1) {
          my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
        }
        __syncthreads();
        if ((threadIdx.x == 0)) {
          sdata[0][threadIdx.y] = my_val;
        }
        __syncthreads();
        if (threadIdx.y == 0 && m < w) {
          dy[m] = sdata[0][threadIdx.x];
        }
      }
    }
  } else {
    // x is reduced (indexed by column); y is read modulo (x_h, x_w).
    for (int m = idx; m < full_width; m += width_stride) {
      sdata[threadIdx.y][threadIdx.x] = 0;
      for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
        int out_offset = n * w + m;
        int y_offset = (n % x_h) * x_w + m % x_w;
        if (dy) {
          if (m < w && n < h) {
            T val = dy_op(x[m], y[y_offset], out[out_offset], dout[out_offset]);
            sdata[threadIdx.y][threadIdx.x] += val;
          }
          __syncthreads();
        }
      }
      if (dy) {
        // Same transposed warp reduction as in the branch above.
        T my_val = sdata[threadIdx.x][threadIdx.y];
        for (int i = warpSize >> 1; i > 0; i >>= 1) {
          my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
        }
        __syncthreads();
        if ((threadIdx.x == 0)) {
          sdata[0][threadIdx.y] = my_val;
        }
        __syncthreads();
        if (threadIdx.y == 0 && m < w) {
          dy[m] = sdata[0][threadIdx.x];
        }
      }
    }
  }
}

// Gradient kernel for a middle-axis broadcast. As the indexing implies:
//  - the larger operand is laid out as [?, n, post], with offset
//    b_i * n * post + mid * post + b_j;
//  - the smaller operand is laid out as [?, post], with offset
//    b_i * post + b_j.
//
// One block handles one (b_i, b_j) = (blockIdx.x / post, blockIdx.x % post)
// pair. Threads stride over the n middle positions in steps of
// ELEMWISE_MAX_BLOCK_DIM. The larger operand's gradient is written
// element-wise. The smaller operand's gradient is summed over the middle axis
// with platform::reduceSum and stored by thread 0. `pre` is not read by this
// kernel.
template <typename T, typename DY_OP, typename DX_OP>
static __global__ void FastCommonGradBroadcastAllCUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
    int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const int lane = threadIdx.x;
  const int block = blockIdx.x;
  T partial(0);

  if (is_xsize_larger) {
    // x is the full tensor; y is broadcast along the middle axis.
    for (int mid = lane; mid < n; mid += ELEMWISE_MAX_BLOCK_DIM) {
      const int b_i = block / post;
      const int b_j = block % post;
      const int full_offset = b_i * n * post + mid * post + b_j;
      const int bcast_offset = b_i * post + b_j;
      if (dx) {
        dx[full_offset] = dx_op(x[full_offset], y[bcast_offset],
                                out[full_offset], dout[full_offset]);
      }
      if (dy) {
        partial += dy_op(x[full_offset], y[bcast_offset], out[full_offset],
                         dout[full_offset]);
      }
    }
    if (dy) {
      const int reduce_len =
          n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
      partial = paddle::platform::reduceSum(partial, lane, reduce_len);
      if (lane == 0) {
        dy[block] = partial;
      }
    }
    return;
  }

  // y is the full tensor; x is broadcast along the middle axis.
  for (int mid = lane; mid < n; mid += ELEMWISE_MAX_BLOCK_DIM) {
    const int b_i = block / post;
    const int b_j = block % post;
    const int full_offset = b_i * n * post + mid * post + b_j;
    const int bcast_offset = b_i * post + b_j;
    if (dy) {
      dy[full_offset] = dy_op(x[bcast_offset], y[full_offset],
                              out[full_offset], dout[full_offset]);
    }
    if (dx) {
      partial += dx_op(x[bcast_offset], y[full_offset], out[full_offset],
                       dout[full_offset]);
    }
  }
  if (dx) {
    const int reduce_len =
        n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
    partial = paddle::platform::reduceSum(partial, lane, reduce_len);
    if (lane == 0) {
      dx[block] = partial;
    }
  }
}

584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641
// Reduces one operand's gradient over a middle axis of length n. Block
// blockIdx.x writes dd[blockIdx.x]; threads stride over the n positions in
// steps of ELEMWISE_MAX_BLOCK_DIM. `is_xsize` picks the operand being
// reduced:
//  - true:  x is reduced and y is the "other" operand;
//  - false: y is reduced and x is the "other" operand.
// The reduced operand is read at b_i * n * post + b_j (independent of the
// middle position). out / dout are read at b_i * n * post + i * post + b_j.
// The other operand's offset is derived from y_pre / y_n / y_post.
// NOTE(review): that offset uses b_yi * y_n rather than a y_n * y_post
// stride, which presumably matches how the caller packs y_pre / y_n /
// y_post; confirm there. `pre` is not read by this kernel.
template <typename T, typename OP>
static __global__ void FastCommonGradBroadcastOneCUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
    int post, int y_pre, int y_n, int y_post, bool is_xsize, OP op, T *dd) {
  const int lane = threadIdx.x;
  const int block = blockIdx.x;
  T partial(0);

  for (int i = lane; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
    const int b_i = block / post;
    const int b_j = block % post;
    const int reduced_offset = b_i * n * post + b_j;
    const int out_offset = b_i * n * post + i * post + b_j;

    const int b_yi = block / (post * y_pre);
    const int b_yj = block % y_post;
    const int other_offset = b_yi * y_n + i * y_post + b_yj;

    if (dd) {
      if (is_xsize) {
        partial += op(x[reduced_offset], y[other_offset], out[out_offset],
                      dout[out_offset]);
      } else {
        partial += op(x[other_offset], y[reduced_offset], out[out_offset],
                      dout[out_offset]);
      }
    }
  }
  if (dd) {
    const int reduce_len =
        n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
    partial = paddle::platform::reduceSum(partial, lane, reduce_len);
    if (lane == 0) {
      dd[block] = partial;
    }
  }
}

642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
// Check input can be split into 2 parts
static inline bool SplitDims(const std::vector<int> &y_broadcast_pos,
                             int max_dim) {
  bool can_split_dim2 = true;
  // must at start or end.
  if (y_broadcast_pos[0] != 0 &&
      y_broadcast_pos[y_broadcast_pos.size() - 1] != max_dim - 1) {
    can_split_dim2 = false;
  } else {
    for (int i = 1; i < y_broadcast_pos.size(); ++i) {
      // dim must be continue
      if (y_broadcast_pos[i] != y_broadcast_pos[i - 1] + 1) {
        can_split_dim2 = false;
        break;
      }
    }
  }
  return can_split_dim2;
}

662 663 664 665 666 667 668 669 670 671
// Returns true when the axis indices in `broadcast_pos` form a single run of
// consecutive integers, i.e. each entry is exactly one more than the previous
// one. Empty and single-element inputs count as contiguous.
static inline bool CheckContiguousDims(const std::vector<int> &broadcast_pos) {
  const auto first_gap = std::adjacent_find(
      broadcast_pos.begin(), broadcast_pos.end(),
      [](int prev, int next) { return next != prev + 1; });
  return first_gap == broadcast_pos.end();
}

672 673 674 675 676 677 678
template <typename T, typename DX_OP, typename DY_OP>
void CommonGradBroadcastCUDA(
    const framework::Tensor &x, const framework::Tensor &y,
    const framework::Tensor &out, const framework::Tensor &dout,
    framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
    int *y_dims_array, int *out_dims_array, int max_dim,
    const platform::CUDADeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
679
  const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734
  auto cplace = platform::CPUPlace();
  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  const T *out_data = out.data<T>();
  const T *dout_data = dout.data<T>();
  T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());

  std::vector<int> x_one_indexs;
  std::vector<int> y_one_indexs;
  for (int i = 0; i < max_dim; i++) {
    if (x_dims_array[i] != y_dims_array[i]) {
      if (x_dims_array[i] == 1) {
        x_one_indexs.push_back(i);
      }
      if (y_dims_array[i] == 1) {
        y_one_indexs.push_back(i);
      }
    }
  }

  std::vector<int> x_trans_indexs(max_dim);
  std::vector<int> y_trans_indexs(max_dim);
  ComputeBroadcastTranspositionArray(x_one_indexs.data(), x_trans_indexs.data(),
                                     max_dim, x_one_indexs.size());
  ComputeBroadcastTranspositionArray(y_one_indexs.data(), y_trans_indexs.data(),
                                     max_dim, y_one_indexs.size());

  // compute array stride for cuda kernel;
  // e.g. x.dims=[2,3,4], x_stride=[12,4,1]
  std::vector<int> x_strides_array(max_dim);
  std::vector<int> y_strides_array(max_dim);
  std::vector<int> out_strides_array(max_dim);
  int x_stride = 1;
  int y_stride = 1;
  int z_stride = 1;
  for (int i = max_dim - 1; i >= 0; i--) {
    x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
    y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
    out_strides_array[i] = z_stride;
    x_stride *= x_dims_array[i];
    y_stride *= y_dims_array[i];
    z_stride *= out_dims_array[i];
  }

  std::vector<int> x_strides_order(max_dim);
  std::vector<int> y_strides_order(max_dim);
  std::vector<int> x_dims_order(max_dim);
  std::vector<int> y_dims_order(max_dim);
  for (int i = 0; i < max_dim; ++i) {
    x_strides_order[i] = out_strides_array[x_trans_indexs[i]];
    y_strides_order[i] = out_strides_array[y_trans_indexs[i]];
    x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
    y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
  }
735 736 737 738 739 740 741 742 743 744 745 746 747
  std::vector<int> x_broadcast_pos;
  std::vector<int> y_broadcast_pos;

  int bytes = max_dim * sizeof(int);

  for (int i = 0; i < max_dim; ++i) {
    if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) {
      x_broadcast_pos.emplace_back(i);
    }
    if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) {
      y_broadcast_pos.emplace_back(i);
    }
  }
748

749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855
  auto stream = ctx.stream();
  bool can_split_x = false;
  bool can_split_y = false;

  auto FastCommonCUDAF = [&](const std::vector<int> &broadcast_pos, bool is_y) {
    int h =
        std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(),
                        1, std::multiplies<int>());
    int w =
        std::accumulate(out_dims_array + broadcast_pos.size(),
                        out_dims_array + max_dim, 1, std::multiplies<int>());

    VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h
            << " is_y:" << is_y;

    int split_h;
    int split_w;
    int kh = h;
    int kw = w;

    if (is_y) {
      split_h =
          std::accumulate(x_dims_array, x_dims_array + broadcast_pos.size(), 1,
                          std::multiplies<int>());
      split_w =
          std::accumulate(x_dims_array + broadcast_pos.size(),
                          x_dims_array + max_dim, 1, std::multiplies<int>());

    } else {
      split_h =
          std::accumulate(y_dims_array, y_dims_array + broadcast_pos.size(), 1,
                          std::multiplies<int>());
      split_w =
          std::accumulate(y_dims_array + broadcast_pos.size(),
                          y_dims_array + max_dim, 1, std::multiplies<int>());
    }

    if (h > split_h) kh = split_h;
    if (w > split_w) kw = split_w;

    if (is_y) {
      if (w < 16 || h < 16) {
        int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
        int grid_size = w;
        CommonGradBroadcast1CUDAKernelHeight<<<grid_size, block_size, 0,
                                               stream>>>(
            x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw,
            is_y);
      } else {
        dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
        int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
        FastCommonGradBroadcastCUDAKernelHeight<<<grid_size, block_size, 0,
                                                  stream>>>(
            x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw,
            is_y);
      }
    } else {
      if (w < 16 || h < 16) {
        int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
        int grid_size = w;
        CommonGradBroadcast1CUDAKernelHeight<<<grid_size, block_size, 0,
                                               stream>>>(
            x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw,
            is_y);
      } else {
        dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
        int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
        FastCommonGradBroadcastCUDAKernelHeight<<<grid_size, block_size, 0,
                                                  stream>>>(
            x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw,
            is_y);
      }
    }
  };

  auto FastBroadCastHeightCUDAF = [&](const std::vector<int> &broadcast_pos,
                                      bool x_large) {
    int h =
        std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(),
                        1, std::multiplies<int>());
    int w =
        std::accumulate(out_dims_array + broadcast_pos.size(),
                        out_dims_array + max_dim, 1, std::multiplies<int>());

    VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h;

    if (w < 16 || h < 16) {
      int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
      int grid_size = w;
      ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
          x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op,
          dx_data, dy_data);
    } else {
      dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
      int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
      FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0,
                                             stream>>>(
          x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op,
          dx_data, dy_data);
    }
  };

  auto FastBroadCastAllCUDAF = [&](const std::vector<int> &broadcast_pos,
                                   int max_dim, bool is_x_large) {
    int axis = broadcast_pos[0];
    int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
                              std::multiplies<int>());
856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871
    int mid = 1;
    int post = 1;

    if (broadcast_pos.size() == 1) {
      mid = out_dims_array[axis];
      post =
          std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim,
                          1, std::multiplies<int>());
    } else {
      mid = std::accumulate(out_dims_array + axis,
                            out_dims_array + broadcast_pos.back() + 1, 1,
                            std::multiplies<int>());
      post =
          std::accumulate(out_dims_array + broadcast_pos.back() + 1,
                          out_dims_array + max_dim, 1, std::multiplies<int>());
    }
872 873 874 875 876 877 878 879 880 881 882 883

    VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid
            << " post:" << post;

    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
    int grid_size = pre * post;

    FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
        x_data, y_data, out_data, dout_data, pre, mid, post, is_x_large, dx_op,
        dy_op, dx_data, dy_data);
  };

884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932
  auto FastBroadCastOneCUDAF = [&](const std::vector<int> &broadcast_pos,
                                   int max_dim, bool is_x) {
    int axis = broadcast_pos[0];
    int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
                              std::multiplies<int>());
    int mid = out_dims_array[axis];
    int post =
        std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, 1,
                        std::multiplies<int>());

    int k_pre;
    int k_mid;
    int k_post;

    if (is_x) {
      k_pre = std::accumulate(y_dims_array, y_dims_array + axis, 1,
                              std::multiplies<int>());
      k_mid = y_dims_array[axis];
      k_post = std::accumulate(y_dims_array + axis + 1, y_dims_array + max_dim,
                               1, std::multiplies<int>());
      int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
      int grid_size = pre * post;
      // we need to calc y offset with blockid, so do x_pre/y_pre to get left
      // size.
      if (k_pre != pre) k_pre = pre / k_pre;

      FastCommonGradBroadcastOneCUDAKernel<<<grid_size, block_size, 0,
                                             stream>>>(
          x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid,
          k_post, true, dx_op, dx_data);
    } else {
      k_pre = std::accumulate(x_dims_array, x_dims_array + axis, 1,
                              std::multiplies<int>());
      k_mid = x_dims_array[axis];
      k_post = std::accumulate(x_dims_array + axis + 1, x_dims_array + max_dim,
                               1, std::multiplies<int>());
      int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
      int grid_size = pre * post;
      if (k_pre != pre) k_pre = pre / k_pre;

      FastCommonGradBroadcastOneCUDAKernel<<<grid_size, block_size, 0,
                                             stream>>>(
          x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid,
          k_post, false, dy_op, dy_data);
    }
    VLOG(3) << "FastBroadCastOneCUDAF pre:" << pre << " mid:" << mid
            << " post:" << post;
  };

933 934 935 936
  // do fast elementwise if: 1. only one input need to do broadcast, we can
  // fallback
  // to old fast path.
  // 2. if both x and y need broadcast, then do it one by one.
937
  bool fast_broadcast = false;
938 939 940 941 942 943
  if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
    can_split_y = SplitDims(y_broadcast_pos, max_dim);
    if (can_split_y) {
      // only y need to do broadcast on h
      if (y_broadcast_pos[0] == 0) {
        FastBroadCastHeightCUDAF(y_broadcast_pos, true);
944
        fast_broadcast = true;
945
      }
946 947 948
    } else if (y_broadcast_pos.size() == 1 ||
               CheckContiguousDims(y_broadcast_pos)) {  // for only one dim and
                                                        // contiguous broadcast.
949 950
      // If cannot split,  which means input has 3 parts
      FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
951
      fast_broadcast = true;
952 953 954 955 956 957 958
    }
  } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
    // only x need broadcast
    can_split_x = SplitDims(x_broadcast_pos, max_dim);
    if (can_split_x) {
      if (x_broadcast_pos[0] == 0) {
        FastBroadCastHeightCUDAF(x_broadcast_pos, false);
959
        fast_broadcast = true;
960
      }
961 962
    } else if (x_broadcast_pos.size() == 1 ||
               CheckContiguousDims(x_broadcast_pos)) {
963
      FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
964
      fast_broadcast = true;
965 966 967 968
    }
  } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
    // do x and y broadcast each.
    can_split_y = SplitDims(y_broadcast_pos, max_dim);
969 970
    bool fast_broadcast_x = false;
    bool fast_broadcast_y = false;
971 972 973 974
    if (can_split_y) {
      // begin at start.
      if (y_broadcast_pos[0] == 0) {
        FastCommonCUDAF(y_broadcast_pos, true);
975
        fast_broadcast_y = true;
976
      }
977 978 979
    } else if (y_broadcast_pos.size() == 1) {
      FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
      can_split_y = true;
980
      fast_broadcast_y = true;
981 982 983 984 985
    }
    can_split_x = SplitDims(x_broadcast_pos, max_dim);
    if (can_split_x) {
      if (x_broadcast_pos[0] == 0) {
        FastCommonCUDAF(x_broadcast_pos, false);
986
        fast_broadcast_x = true;
987
      }
988 989 990
    } else if (x_broadcast_pos.size() == 1) {
      FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
      can_split_x = true;
991
      fast_broadcast_x = true;
992 993 994 995
    }
    VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
            << " can_split_x:" << can_split_x;
    // if both x and y into fast path then return
996 997 998 999
    if (fast_broadcast_x && fast_broadcast_y) {
      fast_broadcast = true;
    }
    if (can_split_y && can_split_x && fast_broadcast) return;
1000
  }
1001

1002
  // Should remove memory copy, use reg instead.
1003 1004 1005
  if (fast_broadcast) {
    return;
  }
1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035
  int x_blocks = 0;
  int x_threads = 0;
  ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
                             &x_threads, max_dim);
  int y_blocks = 0;
  int y_threads = 0;
  ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks,
                             &y_threads, max_dim);

  auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
  int *x_strides_array_gpu =
      reinterpret_cast<int *>(x_strides_array_tmp->ptr());
  memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
               bytes, ctx.stream());

  auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
  int *y_strides_array_gpu =
      reinterpret_cast<int *>(y_strides_array_tmp->ptr());
  memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
               bytes, ctx.stream());

  auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
  int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
  memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
               ctx.stream());

  const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
                                       1, std::multiplies<int>());
  int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
  int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
1036
  if (dx) {
1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052
    auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
    int *x_strides_order_gpu =
        reinterpret_cast<int *>(x_strides_order_tmp->ptr());
    memory::Copy(gplace, x_strides_order_gpu, cplace, x_strides_order.data(),
                 bytes, ctx.stream());

    auto x_dims_order_tmp = memory::Alloc(ctx, bytes);
    int *x_dims_order_gpu = reinterpret_cast<int *>(x_dims_order_tmp->ptr());
    memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(), bytes,
                 ctx.stream());
    CommonGradBroadcastCUDAKernel<
        T, DX_OP><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
        x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
        x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
        dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
  }
1053
  if (dy) {
1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071
    auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
    int *y_strides_order_gpu =
        reinterpret_cast<int *>(y_strides_order_tmp->ptr());
    memory::Copy(gplace, y_strides_order_gpu, cplace, y_strides_order.data(),
                 bytes, ctx.stream());

    auto y_dims_order_tmp = memory::Alloc(ctx, bytes);
    int *y_dims_order_gpu = reinterpret_cast<int *>(y_dims_order_tmp->ptr());
    memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(), bytes,
                 ctx.stream());
    CommonGradBroadcastCUDAKernel<
        T, DY_OP><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
        x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
        y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data,
        dout_data, dy_data, out_size, max_dim, y_threads, dy_op);
  }
}

1072
#endif  // __NVCC__ or __HIPCC__
1073

1074
inline framework::DDim trim_trailing_singular_dims(
1075
    const framework::DDim &dims) {
1076
  return pten::general::trim_trailing_singular_dims(dims);
1077 1078
}

1079 1080
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
1081 1082
class TransformFunctor {
 public:
1083
  TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
1084 1085
                   framework::Tensor *z, const DeviceContext &ctx, Functor func,
                   const bool is_xsize_larger = true)
C
chengduoZH 已提交
1086 1087
      : x_(x->data<T>()),
        y_(y->data<T>()),
1088
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
1089 1090
        nx_(x->numel()),
        ctx_(ctx),
1091 1092 1093 1094 1095 1096
        func_(func),
        is_xsize_larger_(is_xsize_larger) {
    if (is_xsize_larger_ == false) {
      nx_ = y->numel();
    }
  }
C
chengduoZH 已提交
1097 1098

  inline void Run() const {
Q
QI JUN 已提交
1099
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
1100
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
1101 1102 1103
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
1104
    platform::Transform<DeviceContext> trans;
1105 1106
    if (is_xsize_larger_) {
      trans(ctx_, x_, x_ + nx_,
1107 1108
            pten::general::RowwiseTransformIterator<T, DeviceContext>(y_, n),
            z_, func_);
1109 1110
    } else {
      trans(ctx_, y_, y_ + nx_,
1111 1112
            pten::general::RowwiseTransformIterator<T, DeviceContext>(x_, n),
            z_, func_);
1113
    }
C
chengduoZH 已提交
1114 1115 1116
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
1117
    platform::Transform<DeviceContext> trans;
1118 1119
    if (is_xsize_larger_) {
      trans(ctx_, x_, x_ + nx_,
1120 1121 1122
            pten::general::MidWiseTransformIterator<T, DeviceContext>(y_, n,
                                                                      post),
            z_, func_);
1123 1124
    } else {
      trans(ctx_, y_, y_ + nx_,
1125 1126 1127
            pten::general::MidWiseTransformIterator<T, DeviceContext>(x_, n,
                                                                      post),
            z_, func_);
1128 1129 1130
    }
  }

C
chengduoZH 已提交
1131
 private:
1132 1133 1134
  const T *x_;
  const T *y_;
  OutType *z_;
C
chengduoZH 已提交
1135
  int64_t nx_;
1136
  const DeviceContext &ctx_;
C
chengduoZH 已提交
1137
  Functor func_;
1138
  bool is_xsize_larger_;
C
chengduoZH 已提交
1139 1140
};

Y
Yu Yang 已提交
1141 1142
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
1143 1144 1145 1146
  const T *x_;
  const T *y_;
  const T *out_;
  const T *dout_;
Y
Yu Yang 已提交
1147 1148 1149 1150 1151 1152

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_ != nullptr) {
C
chengduoZH 已提交
1153
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
Y
Yu Yang 已提交
1154 1155 1156 1157 1158
    }
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
1159 1160
  T *dx_;
  T *dy_;
Y
Yu Yang 已提交
1161 1162 1163
};

// CPU gradient for the h x w layout: the larger input is an h x w row-major
// matrix and the smaller input has w elements reused for every row.
// is_xsize_larger selects which of x / y is the larger one.
// The larger input's gradient is written elementwise. The smaller input's
// gradient is the sum over the h rows (row 0 assigns, later rows add).
// A null dx or dy skips that gradient.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int h, int w,
                                      bool is_xsize_larger, DX_OP dx_op,
                                      DY_OP dy_op, T *dx, T *dy) {
  if (is_xsize_larger) {
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < w; ++j) {
        // 64-bit offset: h * w can exceed INT_MAX for large tensors.
        const int64_t x_offset = static_cast<int64_t>(i) * w + j;
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          if (i == 0) {
            dy[j] = tmp;
          } else {
            dy[j] += tmp;
          }
        }
      }
    }
  } else {  // x.dims < y.dims, broadcast for x.
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < w; ++j) {
        const int64_t y_offset = static_cast<int64_t>(i) * w + j;
        if (dy != nullptr) {
          dy[y_offset] =
              dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
        }
        if (dx != nullptr) {
          T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
          if (i == 0) {
            dx[j] = tmp;
          } else {
            dx[j] += tmp;
          }
        }
      }
    }
  }
}

#if defined(__NVCC__) || defined(__HIPCC__)
1208

Y
Yu Yang 已提交
1209
template <typename T, typename DX_OP, typename DY_OP>
1210
static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x,
1211
                                       const T *y, const T *out, const T *dout,
1212 1213
                                       int h, int w, bool is_xsize_larger,
                                       DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
1214 1215 1216 1217 1218 1219
  // For small case use 1D block
  constexpr int half_walf = 16;
  if (w < half_walf || h < half_walf) {
    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
    int gird_size = w;
    ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
1220
        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
1221 1222 1223 1224 1225
  } else {
    // suppose perfoemance improves with h increased.
    dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
    int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
    FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
1226
        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
1227
  }
Y
Yu Yang 已提交
1228 1229 1230 1231 1232
}

#endif

// CPU gradient for the pre x n x post layout: the larger input is viewed as
// [pre, n, post] (row-major) and the smaller input has n elements, indexed
// by the middle coordinate j. is_xsize_larger selects which of x / y is the
// larger one. The larger input's gradient is written elementwise. The
// smaller input's gradient sums over all (i, k) for each j (first term
// assigns, the rest add). A null dx or dy skips that gradient.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int pre, int n, int post,
                                      bool is_xsize_larger, DX_OP dx_op,
                                      DY_OP dy_op, T *dx, T *dy) {
  if (is_xsize_larger) {
    for (int i = 0; i < pre; ++i) {
      for (int j = 0; j < n; ++j) {
        for (int k = 0; k < post; ++k) {
          // 64-bit offset: pre * n * post can exceed INT_MAX.
          const int64_t x_offset =
              (static_cast<int64_t>(i) * n + j) * post + k;
          if (dx != nullptr) {
            dx[x_offset] =
                dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          }
          if (dy != nullptr) {
            T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
            if (i == 0 && k == 0) {
              dy[j] = tmp;
            } else {
              dy[j] += tmp;
            }
          }
        }
      }
    }
  } else {  // x.dims < y.dims, broadcast for x.
    for (int i = 0; i < pre; ++i) {
      for (int j = 0; j < n; ++j) {
        for (int k = 0; k < post; ++k) {
          const int64_t y_offset =
              (static_cast<int64_t>(i) * n + j) * post + k;
          if (dy != nullptr) {
            dy[y_offset] =
                dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
          }
          if (dx != nullptr) {
            T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
            if (i == 0 && k == 0) {
              dx[j] = tmp;
            } else {
              dx[j] += tmp;
            }
          }
        }
      }
    }
  }
}

#if defined(__NVCC__) || defined(__HIPCC__)
Y
Yu Yang 已提交
1281 1282
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
1283
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
1284
    int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
Y
Yu Yang 已提交
1285 1286 1287
  int tid = threadIdx.x;
  int j = blockIdx.x;

C
chengduo 已提交
1288
  T val(0);
Y
Yu Yang 已提交
1289 1290
  int ttid = tid;

1291 1292 1293 1294 1295
  if (is_xsize_larger) {
    while (true) {
      int i = ttid / post;
      int k = ttid % post;
      if (i >= pre) break;
Y
Yu Yang 已提交
1296

1297
      int x_offset = i * n * post + j * post + k;
Y
Yu Yang 已提交
1298

1299 1300 1301 1302 1303 1304 1305 1306 1307
      if (dx != nullptr) {
        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }

      if (dy != nullptr) {
        val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }

      ttid += ELEMWISE_MAX_BLOCK_DIM;
Y
Yu Yang 已提交
1308 1309
    }

1310 1311 1312 1313 1314 1315 1316
    if (dy) {
      int h = pre * post;
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
Y
Yu Yang 已提交
1317
    }
1318 1319 1320 1321 1322
  } else {  // x.dims < y.dims, broadcast for x.
    while (true) {
      int i = ttid / post;
      int k = ttid % post;
      if (i >= pre) break;
Y
Yu Yang 已提交
1323

1324
      int y_offset = i * n * post + j * post + k;
Y
Yu Yang 已提交
1325

1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343
      if (dy != nullptr) {
        dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
      }

      if (dx != nullptr) {
        val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
      }

      ttid += ELEMWISE_MAX_BLOCK_DIM;
    }

    if (dx) {
      int h = pre * post;
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
Y
Yu Yang 已提交
1344 1345 1346 1347 1348
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
1349
static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x,
1350
                                       const T *y, const T *out, const T *dout,
1351 1352
                                       int pre, int n, int post,
                                       bool is_xsize_larger, DX_OP dx_op,
1353
                                       DY_OP dy_op, T *dx, T *dy) {
Y
Yu Yang 已提交
1354 1355
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int gird_size = n;
C
chengduoZH 已提交
1356
  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
1357
      x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
Y
Yu Yang 已提交
1358 1359 1360 1361
}

#endif

1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void CommonElementwiseBroadcastBackward(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
    const framework::DDim &y_dims, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  int max_dim = std::max(x_dims.size(), y_dims.size());
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                         y_dims_array.data(), out_dims_array.data(), max_dim,
                         axis);
  // for inplace strategy. memset will make dx and dout clear and get wrong
  // result.
1379
  if (dx && dx->IsSharedBufferWith(dout)) {
1380 1381
    dx->clear();
    dx->mutable_data<T>(x_dims, ctx.GetPlace());
1382 1383
  }

1384 1385 1386 1387
  VLOG(3) << "CommonElementwiseBroadcastBackward xdims:"
          << framework::make_ddim(x_dims_array)
          << " ydim:" << framework::make_ddim(y_dims_array);

1388
  if (platform::is_gpu_place(ctx.GetPlace())) {
1389
#if defined(__NVCC__) || defined(__HIPCC__)
1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401
    CommonGradBroadcastCUDA<T, DX_OP, DY_OP>(
        x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
        out_dims_array.data(), max_dim,
        ctx.template device_context<platform::CUDADeviceContext>(), dx_op,
        dy_op);
#endif
  } else {
    CommonGradBroadcastCPU<T, DX_OP, DY_OP>(
        x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
        out_dims_array.data(), max_dim,
        ctx.template device_context<platform::CPUDeviceContext>(), dx_op,
        dy_op);
1402 1403 1404
  }
}

1405 1406
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
1407 1408 1409 1410 1411
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
1412
  size_t N = static_cast<size_t>(framework::product(x_dim));
D
dzhwinter 已提交
1413
#if !defined(_WIN32)
1414 1415
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
D
dzhwinter 已提交
1416 1417 1418 1419
#else
  platform::ForRange<DeviceContext> for_range(
      ctx.device_context<DeviceContext>(), N);
#endif  // !_WIN32
1420 1421 1422 1423 1424 1425 1426 1427
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
1428 1429
    const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
    const framework::DDim &y_dims, const framework::Tensor &x,
1430 1431 1432
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
1433
  bool is_xsize_larger = true;
1434

1435 1436 1437 1438 1439
  int max_dim = x_dims.size();
  if (x_dims.size() < y_dims.size()) {
    is_xsize_larger = false;
    max_dim = y_dims.size();
  }
1440

1441
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
1442 1443 1444 1445 1446 1447 1448 1449 1450
  PADDLE_ENFORCE_GE(
      axis, 0,
      platform::errors::InvalidArgument(
          "Axis should be great than or equal to 0, but received axis is %d.",
          axis));
  PADDLE_ENFORCE_LT(axis, max_dim,
                    platform::errors::InvalidArgument(
                        "Axis should be less than %d, but received axis is %d.",
                        max_dim, axis));
1451 1452 1453 1454 1455

  int pre, n, post, is_run_common_broadcast, axis_trim = 0;
  if (is_xsize_larger) {
    auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
    axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
1456 1457
    pten::general::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n,
                                &post, &is_run_common_broadcast);
1458 1459 1460
  } else {
    auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
    axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
1461 1462
    pten::general::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n,
                                &post, &is_run_common_broadcast);
1463 1464 1465 1466 1467 1468 1469 1470
  }
  // special case for common backward implementation.
  if (is_run_common_broadcast) {
    CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }
  if (post == 1) {
1471
    if (platform::is_gpu_place(ctx.GetPlace())) {
1472
#if defined(__NVCC__) || defined(__HIPCC__)
1473 1474
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
1475 1476
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, is_xsize_larger,
          dx_op, dy_op,
1477 1478 1479 1480 1481
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
1482
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
1483
          is_xsize_larger, dx_op, dy_op,
1484
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
1485 1486 1487 1488
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
1489
#if defined(__NVCC__) || defined(__HIPCC__)
1490 1491
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
1492 1493 1494
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
          is_xsize_larger, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
1495 1496 1497 1498 1499
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast2CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
1500
          is_xsize_larger, dx_op, dy_op,
1501 1502 1503 1504 1505 1506
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  }
}

1507 1508 1509 1510 1511 1512 1513
// Broadcasting forward pass of `func` over x and y, writing z as OutType.
// z is allocated on ctx's place first. The three fluid tensors are then
// wrapped as pten DenseTensors and pten::CommonElementwiseBroadcastForward
// does the computation.
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
void CommonElementwiseBroadcastForward(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *y, framework::Tensor *z,
    const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
    int axis, const bool is_xsize_larger = true) {
  z->mutable_data<OutType>(ctx.GetPlace());
  auto dense_x = paddle::experimental::MakePtenDenseTensor(*x);
  auto dense_y = paddle::experimental::MakePtenDenseTensor(*y);
  auto dense_z = paddle::experimental::MakePtenDenseTensor(*z);
  const auto &device_ctx = ctx.template device_context<DeviceContext>();
  pten::CommonElementwiseBroadcastForward(device_ctx, *dense_x, *dense_y,
                                          dense_z.get(), x_dims, y_dims, func,
                                          axis, is_xsize_larger);
}

Y
Yu Yang 已提交
1524
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
1525 1526 1527 1528 1529
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                         const framework::Tensor &x, const framework::Tensor &y,
                         const framework::Tensor &out,
                         const framework::Tensor &dout, int axis,
                         framework::Tensor *dx, framework::Tensor *dy,
Y
Yu Yang 已提交
1530
                         DX_OP dx_op, DY_OP dy_op) {
1531 1532
  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
Y
Yu Yang 已提交
1533
  if (x.dims() == y.dims()) {
1534 1535
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
1536
  } else {
1537 1538 1539 1540 1541 1542 1543 1544 1545 1546
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
1547 1548 1549 1550 1551 1552
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
                                 const framework::Tensor &x,
                                 const framework::Tensor &y,
                                 const framework::Tensor &out,
                                 const framework::Tensor &dout, int axis,
                                 framework::Tensor *dx, framework::Tensor *dy,
1553
                                 DX_OP dx_op, DY_OP dy_op) {
1554 1555 1556
  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
  if (x.dims() == y.dims()) {
1557
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
1558
        ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
1559
  } else {
1560 1561
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
1562 1563
  }
}
F
fengjiayi 已提交
1564

1565 1566 1567 1568
// It is a common implementation to compute binary calculation with the support
// of broadcast, supporting both CPU and GPU.
// - CPU implementation cannot support the case when x needs broadcast, thus
//   this function need to be called with XxxFunctor and XxxInverseFunctor,
1569
//   like AddFunctor and InverseAddFunctor.
1570 1571 1572 1573
// - GPU implementation supports all the broadcast cases, thus there is no need
//   to define and call with XxxInverseFunctor.
// TODO(liuyiqun): optimize the CPU implementation to support all broadcast
// cases and avoid the need of XxxInverseFunctor.
1574 1575
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
1576 1577 1578 1579
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                          const framework::Tensor *x,
                          const framework::Tensor *y, int axis, Functor func,
                          framework::Tensor *z) {
1580 1581 1582 1583
  z->mutable_data<OutType>(ctx.GetPlace());
  auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
  auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
  auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597
  if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
    std::vector<const framework::Tensor *> ins = {x, y};
    std::vector<framework::Tensor *> outs = {z};
    z->mutable_data<OutType>(ctx.GetPlace());

    const auto &dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, OutType>(
        dev_ctx, ins, &outs, axis, func);
#endif
    return;
  }

1598 1599 1600 1601
  const auto &dev_ctx =
      ctx.template device_context<platform::CPUDeviceContext>();
  pten::ElementwiseCompute<Functor, T, OutType>(
      dev_ctx, *pt_x.get(), *pt_y.get(), axis, func, pt_z.get());
F
fengjiayi 已提交
1602 1603
}

1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710
// FusedElemwiseAndAct
// --- forward
// Per-element functor (run through platform::ForRange) for the fused forward
// pass when x and y share one shape. With KeepIntermediateOut the value from
// GetIntermediateOut(x, y) is also stored in intermediate_out_[i] and reused
// to produce out_[i]; otherwise GetOut(x, y) is evaluated directly.
template <typename T, typename CompoundFunctor, bool KeepIntermediateOut>
struct FusedElemwiseAndActNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
    T x_val = x_[i];
    T y_val = y_[i];
    if (!KeepIntermediateOut) {
      out_[i] = compound_functor_.GetOut(x_val, y_val);
      return;
    }
    T intermediate_val = compound_functor_.GetIntermediateOut(x_val, y_val);
    intermediate_out_[i] = intermediate_val;
    out_[i] =
        compound_functor_.GetOutUseIntermediateOut(x_val, intermediate_val);
  }

  const T *x_;
  const T *y_;
  CompoundFunctor compound_functor_;
  T *out_;
  T *intermediate_out_;
};

// FusedElemwiseAndActBroadcast1:
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2,
// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20).
// BcastY selects the broadcast row operand. When true, y is read by column
// and x by the full [h, w] offset; when false the roles are swapped.
// With KeepIntermediateOut, the intermediate value is written per column when
// BcastY && !SameShapeOfIntermediateOutAndOut (every row overwrites the same
// slot). Otherwise it is written at the element offset.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y,
                                             CompoundFunctor compound_functor,
                                             int h, int w, T *out,
                                             T *intermediate_out) {
  for (int row = 0; row < h; ++row) {
    for (int col = 0; col < w; ++col) {
      const int offset = row * w + col;
      T x_val = BcastY ? x[offset] : x[col];
      T y_val = BcastY ? y[col] : y[offset];

      if (!KeepIntermediateOut) {
        out[offset] = compound_functor.GetOut(x_val, y_val);
        continue;
      }

      const int64_t inter_idx =
          (BcastY && !SameShapeOfIntermediateOutAndOut) ? col : offset;
      T inter_val = compound_functor.GetIntermediateOut(x_val, y_val);
      intermediate_out[inter_idx] = inter_val;
      out[offset] = compound_functor.GetOutUseIntermediateOut(x_val, inter_val);
    }
  }
}

// FusedElemwiseAndActBroadcast2
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1,
// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1)
// pre = 2, n = 12, post = 5
// BcastY selects which operand is read by the middle index j only; the other
// is read at the full [pre, n, post] offset. With KeepIntermediateOut, the
// intermediate value goes to slot j when BcastY &&
// !SameShapeOfIntermediateOutAndOut, and to the element offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre,
                                             int n, int post,
                                             CompoundFunctor compound_functor,
                                             T *out, T *intermediate_out) {
  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      const int base = (outer * n + mid) * post;
      for (int inner = 0; inner < post; ++inner) {
        const int offset = base + inner;
        T x_val = BcastY ? x[offset] : x[mid];
        T y_val = BcastY ? y[mid] : y[offset];

        if (!KeepIntermediateOut) {
          out[offset] = compound_functor.GetOut(x_val, y_val);
          continue;
        }

        const int64_t inter_idx =
            (BcastY && !SameShapeOfIntermediateOutAndOut) ? mid : offset;
        T inter_val = compound_functor.GetIntermediateOut(x_val, y_val);
        intermediate_out[inter_idx] = inter_val;
        out[offset] =
            compound_functor.GetOutUseIntermediateOut(x_val, inter_val);
      }
    }
  }
}

1711
#if defined(__NVCC__) || defined(__HIPCC__)
1712 1713 1714 1715 1716
// Device kernel for the [h, w] fused forward case. Block blockIdx.x handles
// one row; its threads walk the columns in strides of ELEMWISE_MAX_BLOCK_DIM,
// which is the upper bound on threads per block used by
// FusedElemwiseAndActBroadcast1CUDA. Indexing follows the CPU variant: the
// intermediate value is stored per column when BcastY &&
// !SameShapeOfIntermediateOutAndOut, otherwise per element.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
    const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
    T *out, T *intermediate_out) {
  const int row = blockIdx.x;
  for (int col = threadIdx.x; col < w; col += ELEMWISE_MAX_BLOCK_DIM) {
    const int offset = row * w + col;
    T x_val = BcastY ? x[offset] : x[col];
    T y_val = BcastY ? y[col] : y[offset];

    if (KeepIntermediateOut) {
      const int64_t inter_idx =
          (BcastY && !SameShapeOfIntermediateOutAndOut) ? col : offset;
      T inter_val = compound_functor.GetIntermediateOut(x_val, y_val);
      intermediate_out[inter_idx] = inter_val;
      out[offset] = compound_functor.GetOutUseIntermediateOut(x_val, inter_val);
    } else {
      out[offset] = compound_functor.GetOut(x_val, y_val);
    }
  }
}

template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
1752
static void FusedElemwiseAndActBroadcast1CUDA(gpuStream_t stream, const T *x,
1753 1754 1755 1756
                                              const T *y,
                                              CompoundFunctor compound_functor,
                                              int h, int w, T *out,
                                              T *intermediate_out) {
1757 1758
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, w);
  int gird_size = h;
1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808
  FusedElemwiseAndActBroadcast1CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
      x, y, h, w, compound_functor, out, intermediate_out);
}

// Device kernel for the [pre, n, post] fused forward case. Block blockIdx.x
// owns middle index j. Its threads enumerate the pre * post (i, k) pairs
// through a flat counter advanced by ELEMWISE_MAX_BLOCK_DIM, stopping once the
// derived i reaches pre.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel(
    const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n,
    int post, T *out, T *intermediate_out) {
  const int j = blockIdx.x;
  for (int flat = threadIdx.x;; flat += ELEMWISE_MAX_BLOCK_DIM) {
    const int i = flat / post;
    if (i >= pre) break;
    const int k = flat % post;

    const int offset = i * n * post + j * post + k;
    T x_val = BcastY ? x[offset] : x[j];
    T y_val = BcastY ? y[j] : y[offset];

    if (KeepIntermediateOut) {
      const int64_t inter_idx =
          (BcastY && !SameShapeOfIntermediateOutAndOut) ? j : offset;
      T inter_val = compound_functor.GetIntermediateOut(x_val, y_val);
      intermediate_out[inter_idx] = inter_val;
      out[offset] = compound_functor.GetOutUseIntermediateOut(x_val, inter_val);
    } else {
      out[offset] = compound_functor.GetOut(x_val, y_val);
    }
  }
}

template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
1809
static void FusedElemwiseAndActBroadcast2CUDA(gpuStream_t stream, const T *x,
1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857
                                              const T *y, int pre, int n,
                                              int post,
                                              CompoundFunctor compound_functor,
                                              T *out, T *intermediate_out) {
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int gird_size = n;

  FusedElemwiseAndActBroadcast2CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
      x, y, compound_functor, pre, n, post, out, intermediate_out);
}

#endif

// Fused forward for same-shape x and y: one ForRange iteration per element of
// x_dim runs FusedElemwiseAndActNoBroadcast. `out` is always allocated on
// ctx's place; intermediate_out is allocated only when non-null.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut>
void FusedElemwiseAndActComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::Tensor &x, const framework::Tensor &y,
    CompoundFunctor compound_functor, framework::Tensor *out,
    framework::Tensor *intermediate_out) {
  const size_t numel = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), numel);

  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  T *out_data = out->mutable_data<T>(ctx.GetPlace());
  T *intermediate_data =
      intermediate_out == nullptr
          ? nullptr
          : intermediate_out->mutable_data<T>(ctx.GetPlace());

  for_range(
      FusedElemwiseAndActNoBroadcast<T, CompoundFunctor, KeepIntermediateOut>{
          x_data, y_data, compound_functor, out_data, intermediate_data});
}

// Fused forward (CompoundFunctor) when the operands differ in shape.
// get_mid_dims decomposes x_dim around y_dim (y_dim_untrimed with its trailing
// size-1 dims trimmed) into a [pre, n, post] view starting at `axis`
// (-1 aligns trailing dims). post == 1 uses the 2-D kernels; otherwise the
// 3-D kernels are used. intermediate_out may be nullptr.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool BcastY, bool KeepIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, CompoundFunctor compound_functor, int axis,
    framework::Tensor *out, framework::Tensor *intermediate_out) {
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post, is_run_common_broadcast;
  pten::general::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
                              &is_run_common_broadcast);
  // A non-zero is_run_common_broadcast means the shapes do not fit the
  // [pre, n, post] view (ElemwiseGradComputeWithBroadcast sends that case to
  // CommonElementwiseBroadcastBackward). The fused kernels below have no such
  // fallback. Fail loudly rather than index with pre/n/post values that do
  // not describe the tensors.
  PADDLE_ENFORCE_EQ(
      is_run_common_broadcast, 0,
      platform::errors::Unimplemented(
          "FusedElemwiseAndAct does not support this broadcast pattern: "
          "x_dim = [%s], y_dim = [%s], axis = %d.",
          x_dim, y_dim_untrimed, axis));

  if (post == 1) {
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
      FusedElemwiseAndActBroadcast1CUDA<T, CompoundFunctor, BcastY,
                                        KeepIntermediateOut,
                                        SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), compound_functor, h, w,
          out->mutable_data<T>(ctx.GetPlace()),
          intermediate_out == nullptr
              ? nullptr
              : intermediate_out->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      FusedElemwiseAndActBroadcast1CPU<T, CompoundFunctor, BcastY,
                                       KeepIntermediateOut,
                                       SameShapeOfIntermediateOutAndOut>(
          x.data<T>(), y.data<T>(), compound_functor, h, w,
          out->mutable_data<T>(ctx.GetPlace()),
          intermediate_out == nullptr
              ? nullptr
              : intermediate_out->mutable_data<T>(ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
      FusedElemwiseAndActBroadcast2CUDA<T, CompoundFunctor, BcastY,
                                        KeepIntermediateOut,
                                        SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), pre, n, post, compound_functor,
          out->mutable_data<T>(ctx.GetPlace()),
          intermediate_out == nullptr
              ? nullptr
              : intermediate_out->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      FusedElemwiseAndActBroadcast2CPU<T, CompoundFunctor, BcastY,
                                       KeepIntermediateOut,
                                       SameShapeOfIntermediateOutAndOut>(
          x.data<T>(), y.data<T>(), pre, n, post, compound_functor,
          out->mutable_data<T>(ctx.GetPlace()),
          intermediate_out == nullptr
              ? nullptr
              : intermediate_out->mutable_data<T>(ctx.GetPlace()));
    }
  }
}

// --- backward
C
chengduo 已提交
1913 1914
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut>
1915 1916
struct FusedElemwiseAndActGradNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
1917 1918 1919
    T zero = static_cast<T>(0);
    T x_val = (x_ == nullptr) ? zero : x_[i];
    T y_val = (y_ == nullptr) ? zero : y_[i];
1920 1921 1922 1923 1924
    T out_val = out_[i];
    T dout_val = dout_[i];
    T intermediate_out_val = UseIntermediateOut
                                 ? intermediate_out_[i]
                                 : dx_op_.GetIntermediateOut(x_val, y_val);
1925
    if (dx_ != nullptr) {
1926 1927
      dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
                                         out_val, dout_val);
1928 1929
    }
    if (dy_ != nullptr) {
1930 1931
      dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
                                         out_val, dout_val);
C
chengduo 已提交
1932 1933
    }
    if (dintermediate_ != nullptr) {
1934 1935
      dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
          x_val, intermediate_out_val, out_val, dout_val);
1936 1937 1938 1939 1940 1941 1942 1943 1944 1945
    }
  }

  const T *x_;
  const T *y_;
  const T *intermediate_out_;
  const T *out_;
  const T *dout_;
  DX_OP dx_op_;
  DY_OP dy_op_;
C
chengduo 已提交
1946
  DIntermediate_OP dintermediate_op_;
1947 1948
  T *dx_;
  T *dy_;
C
chengduo 已提交
1949
  T *dintermediate_;
1950 1951 1952
};

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
C
chengduo 已提交
1953
          typename DIntermediate_OP, bool UseIntermediateOut>
1954 1955 1956 1957 1958
void FusedElemwiseAndActGradComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
C
chengduo 已提交
1959 1960 1961
    framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
1962 1963 1964
  size_t N = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977
  const T *x_data = nullptr;
  const T *y_data = nullptr;
  if (x->IsInitialized()) x_data = x->data<T>();
  if (y->IsInitialized()) y_data = y->data<T>();

  for_range(FusedElemwiseAndActGradNoBroadcast<
            T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>{
      x_data, y_data, intermediate_out ? intermediate_out->data<T>() : nullptr,
      out->data<T>(), dout->data<T>(), dx_op, dy_op, dintermediate_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
      dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
                                               ctx.GetPlace())});
1978 1979
}

C
chengduo 已提交
1980 1981 1982 1983 1984 1985 1986
// CPU gradient of the fused op for the 2-D broadcast case ([h, w] view).
// When BcastY is true, y is read by column j and x by the element offset; when
// false the roles swap. The gradient of the operand read by column is summed
// over all rows. A null x / y is read as zero. With UseIntermediateOut the
// stored intermediate_out is passed to the ops; otherwise they Recompute from
// x, y and out. Every op reads dout at the element offset.
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  // d_intermediate has one slot per column, shared by every row, only when it
  // is indexed by j (BcastY && !SameShapeOfIntermediateOutAndOut). In every
  // other configuration each element owns its slot, so the gradient must be
  // assigned rather than accumulated into memory no row has written.
  constexpr bool kSharedIntermediate =
      BcastY && !SameShapeOfIntermediateOutAndOut;
  int64_t tmp_out_idx, x_idx, y_idx;
  T zero = static_cast<T>(0);
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      int offset = i * w + j;

      tmp_out_idx = BcastY ? j : offset;
      y_idx = BcastY ? j : offset;
      x_idx = BcastY ? offset : j;

      T x_val = (x == nullptr) ? zero : x[x_idx];
      T y_val = (y == nullptr) ? zero : y[y_idx];

      if (SameShapeOfIntermediateOutAndOut) {
        tmp_out_idx = offset;
      }

      if (dx != nullptr) {
        T tmp = UseIntermediateOut
                    ? dx_op.UseIntermediateOut(x_val, y_val,
                                               intermediate_out[tmp_out_idx],
                                               out[offset], dout[offset])
                    : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
        // x is per-column when !BcastY: sum its gradient over rows.
        if (BcastY || i == 0) {
          dx[x_idx] = tmp;
        } else {
          dx[x_idx] += tmp;
        }
      }
      if (dy != nullptr) {
        T tmp = UseIntermediateOut
                    ? dy_op.UseIntermediateOut(x_val, y_val,
                                               intermediate_out[tmp_out_idx],
                                               out[offset], dout[offset])
                    : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
        // y is per-column when BcastY: sum its gradient over rows.
        if (!BcastY || i == 0) {
          dy[y_idx] = tmp;
        } else {
          dy[y_idx] += tmp;
        }
      }
      if (d_intermediate != nullptr) {
        // dout is indexed by the element offset, as in the dx / dy branches.
        T tmp = UseIntermediateOut
                    ? dintermediate_op.UseIntermediateOut(
                          x_val, intermediate_out[tmp_out_idx], out[offset],
                          dout[offset])
                    : dintermediate_op.Recompute(x_val, y_val, out[offset],
                                                 dout[offset]);
        if (kSharedIntermediate && i != 0) {
          d_intermediate[tmp_out_idx] += tmp;
        } else {
          d_intermediate[tmp_out_idx] = tmp;
        }
      }
    }
  }
}

C
chengduo 已提交
2057 2058 2059 2060 2061 2062 2063
// CPU gradient of the fused op for the 3-D broadcast case ([pre, n, post]).
// When BcastY is true, y is read by the middle index j and x by the element
// offset; when false the roles swap. The gradient of the operand read by j is
// summed over all (i, k). A null x / y is read as zero. With
// UseIntermediateOut the stored intermediate_out is passed to the ops;
// otherwise they Recompute from x, y and out. Every op reads dout at the
// element offset.
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  // d_intermediate slots are shared across (i, k) only when indexed by j
  // (BcastY && !SameShapeOfIntermediateOutAndOut). In every other
  // configuration each element owns its slot and is assigned.
  constexpr bool kSharedIntermediate =
      BcastY && !SameShapeOfIntermediateOutAndOut;
  int64_t tmp_out_idx, x_idx, y_idx;
  T zero = static_cast<T>(0);
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        int offset = i * n * post + j * post + k;
        // The first visit of a j-indexed slot is at (i, k) == (0, 0).
        const bool first_visit = (i == 0 && k == 0);

        tmp_out_idx = BcastY ? j : offset;
        y_idx = BcastY ? j : offset;
        x_idx = BcastY ? offset : j;

        T x_val = (x == nullptr) ? zero : x[x_idx];
        T y_val = (y == nullptr) ? zero : y[y_idx];

        if (SameShapeOfIntermediateOutAndOut) {
          tmp_out_idx = offset;
        }

        if (dx != nullptr) {
          T tmp =
              UseIntermediateOut
                  ? dx_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
          if (BcastY || first_visit) {
            dx[x_idx] = tmp;
          } else {
            dx[x_idx] += tmp;
          }
        }
        if (dy != nullptr) {
          T tmp =
              UseIntermediateOut
                  ? dy_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
          if (!BcastY || first_visit) {
            dy[y_idx] = tmp;
          } else {
            dy[y_idx] += tmp;
          }
        }
        if (d_intermediate != nullptr) {
          // dout is indexed by the element offset, as in the dx / dy branches.
          T tmp = UseIntermediateOut
                      ? dintermediate_op.UseIntermediateOut(
                            x_val, intermediate_out[tmp_out_idx], out[offset],
                            dout[offset])
                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
                                                   dout[offset]);
          if (kSharedIntermediate && !first_visit) {
            d_intermediate[tmp_out_idx] += tmp;
          } else {
            d_intermediate[tmp_out_idx] = tmp;
          }
        }
      }
    }
  }
}

2139
#if defined(__NVCC__) || defined(__HIPCC__)
C
chengduo 已提交
2140 2141 2142
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
2143 2144
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
C
chengduo 已提交
2145 2146
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
2147 2148 2149 2150 2151 2152
  __shared__ T sdata[BLOCK_Y][BLOCK_X];
  size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
  size_t width_stride = gridDim.x * BLOCK_X;

  size_t full_w = ROUNDUP(w, BLOCK_X);

2153
  T zero = static_cast<T>(0);
2154

2155 2156 2157 2158 2159
  for (size_t j = idx; j < full_w; j += width_stride) {
    T val(0), inter_val(0);
    if (j < w) {
      for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) {
        size_t offset = i * w + j;
2160

2161 2162 2163 2164 2165
        size_t tmp_out_idx = BcastY ? j : offset;
        size_t y_idx = BcastY ? j : offset;
        size_t x_idx = BcastY ? offset : j;
        T x_val = (x == nullptr) ? zero : x[x_idx];
        T y_val = (y == nullptr) ? zero : y[y_idx];
2166

2167 2168 2169
        if (SameShapeOfIntermediateOutAndOut) {
          tmp_out_idx = offset;
        }
2170

2171 2172 2173
        if (dx != nullptr) {
          T tmp =
              UseIntermediateOut
2174 2175 2176 2177
                  ? dx_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
2178

2179 2180 2181 2182 2183 2184 2185 2186 2187
          if (BcastY) {
            dx[x_idx] = tmp;
          } else {
            val += tmp;
          }
        }
        if (dy != nullptr) {
          T tmp =
              UseIntermediateOut
2188 2189 2190 2191
                  ? dy_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210
          if (BcastY) {
            val += tmp;
          } else {
            dy[y_idx] = tmp;
          }
        }
        if (d_intermediate != nullptr) {
          T tmp = UseIntermediateOut
                      ? dintermediate_op.UseIntermediateOut(
                            y[y_idx], intermediate_out[tmp_out_idx],
                            out[offset], dout[offset])
                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
                                                   dout[offset]);
          if (SameShapeOfIntermediateOutAndOut) {
            d_intermediate[tmp_out_idx] = tmp;
          } else {
            inter_val += tmp;
          }
        }
C
chengduo 已提交
2211 2212
      }
    }
2213

2214 2215 2216 2217 2218 2219 2220 2221 2222
    // transpose, for ReduceSum with wrap
    sdata[threadIdx.y][threadIdx.x] = val;
    __syncthreads();
    val = sdata[threadIdx.x][threadIdx.y];
#pragma unroll
    for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
      // reduce sum with wrap
      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
    }
2223

2224 2225 2226 2227
    size_t idx_j = j + threadIdx.y;
    if (BcastY) {
      if (dy) {
        if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val;
2228
      }
2229 2230 2231
    } else {
      if (dx) {
        if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val;
2232 2233
      }
    }
2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245

    if (!SameShapeOfIntermediateOutAndOut) {
      if (d_intermediate) {
        sdata[threadIdx.y][threadIdx.x] = inter_val;
        __syncthreads();
        inter_val = sdata[threadIdx.x][threadIdx.y];
#pragma unroll
        for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
          // reduce sum with wrap
          inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
        }
        if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
C
chengduo 已提交
2246 2247
      }
    }
2248
  }  // end for
2249 2250
}

C
chengduo 已提交
2251 2252 2253 2254
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(
2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266
    const framework::ExecutionContext &ctx, const T *x, const T *y,
    const T *intermediate_out, const T *out, const T *dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
    T *d_intermediate) {
  gpuStream_t stream = ctx.cuda_device_context().stream();

  dim3 blocks(BLOCK_X, BLOCK_Y);
  int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
  int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
  int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
  dim3 grids(std::min(theory_block, max_blocks));

2267
  FusedElemwiseAndActGradBroadcast1CUDAKernel<
C
chengduo 已提交
2268
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
2269
      SameShapeOfIntermediateOutAndOut><<<grids, blocks, 0, stream>>>(
C
chengduo 已提交
2270 2271
      x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
      dx, dy, d_intermediate);
2272 2273
}

C
chengduo 已提交
2274 2275 2276
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
2277 2278
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
C
chengduo 已提交
2279 2280
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
2281 2282 2283
  int tid = threadIdx.x;
  int j = blockIdx.x;

C
chengduo 已提交
2284
  T val(0), inter_val(0);
2285 2286
  int ttid = tid;
  int64_t tmp_out_idx, x_idx, y_idx;
2287
  T zero = static_cast<T>(0);
2288 2289 2290 2291 2292 2293 2294 2295 2296 2297
  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int offset = i * n * post + j * post + k;

    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;
2298 2299
    T x_val = (x == nullptr) ? zero : x[x_idx];
    T y_val = (y == nullptr) ? zero : y[y_idx];
2300 2301 2302 2303 2304 2305

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
2306 2307 2308 2309 2310
      T tmp = UseIntermediateOut
                  ? dx_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
2311 2312 2313 2314 2315 2316 2317 2318

      if (BcastY) {
        dx[x_idx] = tmp;
      } else {
        val += tmp;
      }
    }
    if (dy != nullptr) {
2319 2320 2321 2322 2323
      T tmp = UseIntermediateOut
                  ? dy_op.UseIntermediateOut(x_val, y_val,
                                             intermediate_out[tmp_out_idx],
                                             out[offset], dout[offset])
                  : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
2324 2325 2326 2327 2328 2329
      if (BcastY) {
        val += tmp;
      } else {
        dy[y_idx] = tmp;
      }
    }
C
chengduo 已提交
2330 2331 2332
    if (d_intermediate != nullptr) {
      T tmp = UseIntermediateOut
                  ? dintermediate_op.UseIntermediateOut(
2333
                        y_val, intermediate_out[tmp_out_idx], out[offset],
C
chengduo 已提交
2334
                        dout[offset])
2335
                  : dintermediate_op.Recompute(x_val, y_val, out[offset],
C
chengduo 已提交
2336 2337 2338 2339 2340 2341 2342
                                               dout[offset]);
      if (SameShapeOfIntermediateOutAndOut) {
        d_intermediate[tmp_out_idx] = tmp;
      } else {
        inter_val += tmp;
      }
    }
2343 2344 2345
    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

C
chengduo 已提交
2346 2347
  int h = pre * post;
  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362
  if (BcastY) {
    if (dy) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
C
chengduo 已提交
2363 2364 2365 2366 2367 2368 2369 2370
  if (!SameShapeOfIntermediateOutAndOut) {
    if (d_intermediate) {
      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
      if (threadIdx.x == 0) {
        d_intermediate[j] = inter_val;
      }
    }
  }
2371 2372
}

C
chengduo 已提交
2373 2374 2375
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
2376
static void FusedElemwiseAndActGradBroadcast2CUDA(
2377
    gpuStream_t stream, const T *x, const T *y, const T *intermediate_out,
2378
    const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
C
chengduo 已提交
2379 2380
    DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
    T *dintermediate) {
2381 2382 2383
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int gird_size = n;
  FusedElemwiseAndActGradBroadcast2CUDAKernel<
C
chengduo 已提交
2384
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
2385
      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
C
chengduo 已提交
2386 2387
      x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op,
      dintermediate_op, dx, dy, dintermediate);
2388 2389 2390 2391
}
#endif

// Computes the gradients of a fused elementwise + activation op when one
// operand is broadcast against the other.
//
// x_dim is the shape of the full-size operand; y_dim_untrimed is the shape of
// the broadcast operand. BcastY says whether the broadcast operand is y
// (true) or x (false). Trailing size-1 dims of the broadcast shape are
// trimmed. The full shape is then split by get_mid_dims into [pre, n, post]
// around `axis`. When post == 1, the [h = pre, w = n] kernels (*Broadcast1*)
// are used; otherwise the [pre, n, post] kernels (*Broadcast2*) are used,
// on GPU or CPU depending on ctx's place.
//
// x / y that are not initialized are passed to the kernels as nullptr.
// Any of dx, dy, dintermediate may be nullptr to skip that gradient.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          typename DIntermediate_OP, bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
    framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post, is_run_common_broadcast;
  pten::general::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
                              &is_run_common_broadcast);
  // The kernels below only handle the [pre, n, post] decomposition. When
  // get_mid_dims reports that a general broadcast would be needed, pre/n/post
  // are not a valid split of the shapes. Previously this flag was ignored,
  // which silently produced wrong gradients.
  PADDLE_ENFORCE_EQ(
      is_run_common_broadcast, 0,
      platform::errors::Unimplemented(
          "The fused elementwise-activation gradient only supports "
          "broadcasting when the broadcast operand's shape (with trailing 1s "
          "trimmed) equals a contiguous run of the other operand's dims "
          "starting at axis, but received shapes [%s] and [%s] with axis = "
          "%d.",
          x_dim, y_dim_untrimed, axis));

  // Resolve every data pointer once. The CPU and GPU paths below take the
  // same arguments.
  const T *x_data = x->IsInitialized() ? x->data<T>() : nullptr;
  const T *y_data = y->IsInitialized() ? y->data<T>() : nullptr;
  const T *intermediate_out_data =
      intermediate_out == nullptr ? nullptr : intermediate_out->data<T>();
  const T *out_data = out->data<T>();
  const T *dout_data = dout->data<T>();
  T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
  T *dintermediate_data =
      dintermediate == nullptr ? nullptr
                               : dintermediate->mutable_data<T>(ctx.GetPlace());

  const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());
  if (post == 1) {
    // Broadcast operand covers the trailing dims: view as an [h, w] matrix.
    int h = pre;
    int w = n;
    if (on_gpu) {
#if defined(__NVCC__) || defined(__HIPCC__)
      FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          ctx, x_data, y_data, intermediate_out_data, out_data, dout_data, h,
          w, dx_op, dy_op, dintermediate_op, dx_data, dy_data,
          dintermediate_data);
#endif
    } else {
      FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
          x_data, y_data, intermediate_out_data, out_data, dout_data, h, w,
          dx_op, dy_op, dintermediate_op, dx_data, dy_data,
          dintermediate_data);
    }
  } else {
    if (on_gpu) {
#if defined(__NVCC__) || defined(__HIPCC__)
      FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
          intermediate_out_data, out_data, dout_data, pre, n, post, dx_op,
          dy_op, dintermediate_op, dx_data, dy_data, dintermediate_data);
#endif
    } else {
      FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
          x_data, y_data, intermediate_out_data, out_data, dout_data, pre, n,
          post, dx_op, dy_op, dintermediate_op, dx_data, dy_data,
          dintermediate_data);
    }
  }
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
C
chengduo 已提交
2474 2475
          typename DIntermediate_OP, bool UseIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
2476 2477 2478 2479
void FusedElemwiseAndActGradComputeEx(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *out,
    const framework::Tensor *intermediate_out, const framework::Tensor *dout,
C
chengduo 已提交
2480 2481 2482
    int axis, framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
2483 2484 2485
  const framework::DDim &x_dim = x->dims();
  const framework::DDim &y_dim = y->dims();
  if (UseIntermediateOut) {
2486 2487 2488
    PADDLE_ENFORCE_NOT_NULL(
        intermediate_out,
        platform::errors::InvalidArgument("Intermediate out is null pointer."));
2489 2490
  }
  if (x_dim == y_dim) {
C
chengduo 已提交
2491 2492
    FusedElemwiseAndActGradComputeNoBroadcast<
        DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>(
2493
        ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
C
chengduo 已提交
2494
        dintermediate, dx_op, dy_op, dintermediate_op);
2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509
  } else {  // Y is a scalar
    bool bcast_y = x_dim.size() >= y_dim.size();
    if (x_dim.size() == y_dim.size()) {
      for (int i = 0; i < x_dim.size(); ++i) {
        if (x_dim[i] < y_dim[i]) {
          bcast_y = false;
          break;
        }
      }
    }

    // z = f1(x, f2(y))
    // z = f1(f2(x, y))
    if (bcast_y) {  // Y should be broadcast.
      FusedElemwiseAndActGradComputeWithBroadcast<
C
chengduo 已提交
2510 2511 2512 2513
          DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
          true /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
          ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
          dintermediate, dx_op, dy_op, dintermediate_op);
2514 2515
    } else {
      FusedElemwiseAndActGradComputeWithBroadcast<
C
chengduo 已提交
2516 2517 2518 2519
          DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
          false /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
          ctx, y_dim, x_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
          dintermediate, dx_op, dy_op, dintermediate_op);
2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532
    }
  }
}

// Forward computation of a fused elementwise + activation op driven by
// `compound_functor`.
//
// Inputs with identical dims use the no-broadcast path. Otherwise the
// operand with fewer elements is broadcast against the other (y when
// x.numel() >= y.numel()), and the other operand's dims are passed as the
// output shape.
// intermediate_out must be non-null when KeepIntermediateOut is true.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
                                  const framework::Tensor &x,
                                  const framework::Tensor &y, int axis,
                                  CompoundFunctor compound_functor,
                                  framework::Tensor *out,
                                  framework::Tensor *intermediate_out) {
  if (KeepIntermediateOut) {
    PADDLE_ENFORCE_NOT_NULL(
        intermediate_out,
        platform::errors::InvalidArgument(
            "The save_intermediate_out is opened, intermediate "
            "out is null pointer."));
  }

  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
  if (x_dim == y_dim) {
    FusedElemwiseAndActComputeNoBroadcast<DeviceContext, T, CompoundFunctor,
                                          KeepIntermediateOut>(
        ctx, x_dim, x, y, compound_functor, out, intermediate_out);
    return;
  }

  // The shape of the broadcast operand should be a contiguous subsequence of
  // the other operand's shape; see the op's introduction for details.
  // Supported compositions: z = f1(x, f2(y)) and z = f1(f2(x, y)).
  const bool bcast_y = x.numel() >= y.numel();
  if (!bcast_y) {
    // x is broadcast; out has the shape of y.
    // For 'f2(y)', intermediate_out should have the shape of out.
    // For 'f2(x, y)', intermediate_out should have the shape of out.
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, false /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out,
        intermediate_out);
  } else {
    // y is broadcast; out has the shape of x.
    // For 'f2(y)', intermediate_out should have the shape of y.
    // For 'f2(x, y)', intermediate_out should have the shape of out.
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, true /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out,
        intermediate_out);
  }
}
2581 2582 2583 2584 2585 2586 2587 2588

// Produces a tensor that can always be read as the second-order gradient
// input. If ddx is given, *ddx_safe is assigned from *ddx. Otherwise a
// temporary tensor with x's dims is allocated on the DeviceContext and
// filled with zeros.
template <typename DeviceContext, typename T>
static inline void GetDoubleGradSafeTensor(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *ddx, framework::Tensor *ddx_safe) {
  if (ddx != nullptr) {
    *ddx_safe = *ddx;
    return;
  }
  auto &device_ctx = ctx.template device_context<DeviceContext>();
  *ddx_safe = ctx.AllocateTmpTensor<T, DeviceContext>(x->dims(), device_ctx);
  math::SetConstant<DeviceContext, T> fill_zero;
  fill_zero(device_ctx, ddx_safe, static_cast<T>(0));
}

2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616
// For broadcast backward passes: returns, in ascending order, the dims of
// `out` that must be summed to reduce a gradient shaped like `out` down to
// the shape `in`.
// `in` is aligned with `out` starting at `axis`; axis == -1 means
// |out.size() - in.size()|. The reduced dims are:
//   - the dims before the aligned range,
//   - the aligned dims whose extents differ,
//   - the dims after the aligned range.
static inline std::vector<int> GetReduceDim(const framework::DDim &in,
                                            const framework::DDim &out,
                                            int axis) {
  const int aligned_begin =
      (axis == -1 ? std::abs(static_cast<int>(out.size() - in.size())) : axis);
  const int aligned_end = aligned_begin + in.size();

  std::vector<int> reduce_dims;
  for (int d = 0; d < aligned_begin; ++d) {
    reduce_dims.push_back(d);
  }
  for (int d = aligned_begin; d < aligned_end; ++d) {
    if (out[d] != in[d - aligned_begin]) {
      reduce_dims.push_back(d);
    }
  }
  for (int d = aligned_end; d < out.size(); ++d) {
    reduce_dims.push_back(d);
  }
  return reduce_dims;
}
2617 2618
}  // namespace operators
}  // namespace paddle