elementwise_op_function.h 58.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#pragma once
16

17
#include <glog/logging.h>
18
#include <algorithm>
D
dzhwinter 已提交
19
#include <iterator>
20
#include <vector>
Y
Yi Wang 已提交
21 22 23 24
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
25

C
chengduoZH 已提交
26
#ifdef __NVCC__
27
#include <cuda.h>
C
chengduoZH 已提交
28
#include <thrust/iterator/iterator_adaptor.h>
29
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
30
#include "paddle/fluid/platform/cuda_primitives.h"
Y
Yu Yang 已提交
31
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
C
chengduoZH 已提交
32 33
#endif

Y
Yi Wang 已提交
34
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
35
#include "paddle/fluid/platform/for_range.h"
36 37 38 39 40 41 42 43 44 45

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, they will be reshaped.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
51 52 53
// Factors x_dims into three products around the span that y_dims covers when
// its first dimension is aligned with x_dims[axis]:
//   *pre  : product of x_dims[0 .. axis)
//   *n    : product of all y_dims entries
//   *post : product of x_dims[axis + rank(y) .. rank(x))
// Each y_dims[i] must equal x_dims[i + axis]; otherwise PADDLE_ENFORCE_EQ
// reports "Broadcast dimension mismatch.".
inline void get_mid_dims(const framework::DDim &x_dims,
                         const framework::DDim &y_dims, const int axis,
                         int *pre, int *n, int *post) {
  const int x_rank = x_dims.size();
  const int y_rank = y_dims.size();

  *pre = 1;
  *n = 1;
  *post = 1;

  // Dimensions of X in front of the aligned span.
  for (int d = 0; d < axis; ++d) {
    *pre *= x_dims[d];
  }

  // Dimensions shared by X and Y.
  for (int i = 0; i < y_rank; ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    *n *= y_dims[i];
  }

  // Dimensions of X behind the aligned span.
  for (int d = axis + y_rank; d < x_rank; ++d) {
    *post *= x_dims[d];
  }
}

72
// Returns a copy of `dims` without its trailing dimensions of size 1.
// When every dimension is 1 (or `dims` is empty) the result is
// framework::DDim(framework::make_dim()).
inline framework::DDim trim_trailing_singular_dims(
    const framework::DDim &dims) {
  // Walk backwards until the last dimension whose size is not 1.
  auto kept = dims.size();
  while (kept != 0 && dims[kept - 1] == 1) {
    --kept;
  }

  if (kept == 0) {
    return framework::DDim(framework::make_dim());
  }

  std::vector<int> trim_dims(kept);
  for (int i = 0; i < kept; ++i) {
    trim_dims[i] = dims[i];
  }
  return framework::make_ddim(trim_dims);
}

Q
QI JUN 已提交
92
// Primary templates for the broadcasting iterators over Y. They are only
// declared here; the usable definitions are the per-device specializations.
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;

template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
C
chengduoZH 已提交
97

D
dzhwinter 已提交
98
// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
C
chengduoZH 已提交
99
template <typename T>
D
dzhwinter 已提交
100 101 102
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
                           T *, T &> {
C
chengduoZH 已提交
103
 public:
104
  RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
C
chengduoZH 已提交
105

106
  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
C
chengduoZH 已提交
107
    ++i_;
C
chengduoZH 已提交
108 109 110
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
C
chengduoZH 已提交
111 112 113
    return *this;
  }

P
peizhilin 已提交
114
  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
P
peizhilin 已提交
115
    while (n-- > 0) {
P
peizhilin 已提交
116 117 118 119 120 121 122 123 124
      ++i_;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }

    return *this;
  }

125 126
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
127
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
128 129
  }

130 131
  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
132
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
133 134
  }

135
  const T &operator*() { return ptr_[i_]; }
C
chengduoZH 已提交
136

C
chengduoZH 已提交
137
 private:
138
  const T *ptr_;
C
chengduoZH 已提交
139
  int i_;
C
chengduoZH 已提交
140
  int64_t n_;
C
chengduoZH 已提交
141 142 143
};

template <typename T>
D
dzhwinter 已提交
144 145 146
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
                           T *, T &> {
C
chengduoZH 已提交
147
 public:
148
  MidWiseTransformIterator(const T *ptr, int n, int post)
C
chengduoZH 已提交
149 150
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

151
  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
C
chengduoZH 已提交
152
    ++j_;
C
chengduoZH 已提交
153 154
    if (UNLIKELY(j_ == post_)) {
      ++i_;
C
refine  
chengduoZH 已提交
155
      j_ = 0;
C
chengduoZH 已提交
156 157 158
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
C
chengduoZH 已提交
159
    }
C
chengduoZH 已提交
160 161 162
    return *this;
  }

P
peizhilin 已提交
163
  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
P
peizhilin 已提交
164
    while (n-- > 0) {
P
peizhilin 已提交
165 166 167 168 169 170 171 172 173 174 175 176 177
      ++j_;
      if (UNLIKELY(j_ == post_)) {
        ++i_;
        j_ = 0;
        if (UNLIKELY(i_ == n_)) {
          i_ = 0;
        }
      }
    }

    return *this;
  }

178 179
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
180
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
181 182
  }

183 184
  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
185
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
186 187
  }

188
  const T &operator*() { return ptr_[i_]; }
C
chengduoZH 已提交
189

C
chengduoZH 已提交
190
 private:
191
  const T *ptr_;
C
refine  
chengduoZH 已提交
192
  int64_t i_;
C
chengduoZH 已提交
193 194
  int64_t j_;
  int64_t n_;
C
refine  
chengduoZH 已提交
195
  int64_t post_;
C
chengduoZH 已提交
196 197
};

C
chengduoZH 已提交
198 199
#ifdef __NVCC__
template <typename T>
Q
QI JUN 已提交
200
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
201
    : public thrust::iterator_adaptor<
202
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
C
chengduoZH 已提交
203 204
 public:
  typedef thrust::iterator_adaptor<
205
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
C
chengduoZH 已提交
206
      super_t;
207
  HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
208
      : super_t(x), begin_(x), n_(n) {}
C
chengduoZH 已提交
209 210 211 212
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
213
  const T *begin_;
C
chengduoZH 已提交
214
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
215 216 217 218 219
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
Q
QI JUN 已提交
220
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
221
    : public thrust::iterator_adaptor<
222
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
C
chengduoZH 已提交
223 224
 public:
  typedef thrust::iterator_adaptor<
225
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
C
chengduoZH 已提交
226
      super_t;
227
  HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
228
      : super_t(x), begin_(x), n_(n), post_(post) {}
C
chengduoZH 已提交
229 230 231 232 233
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
234
  const T *begin_;
C
chengduoZH 已提交
235
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
236 237 238 239 240
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

241 242
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
243 244
class TransformFunctor {
 public:
245 246
  TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
                   framework::Tensor *z, const DeviceContext &ctx, Functor func)
C
chengduoZH 已提交
247 248
      : x_(x->data<T>()),
        y_(y->data<T>()),
249
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
250 251 252 253 254
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
Q
QI JUN 已提交
255
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
256
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
257 258 259
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
260 261 262
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
C
chengduoZH 已提交
263 264 265
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
266 267 268
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
C
chengduoZH 已提交
269 270
  }

C
chengduoZH 已提交
271
 private:
272 273 274
  const T *x_;
  const T *y_;
  OutType *z_;
C
chengduoZH 已提交
275
  int64_t nx_;
276
  const DeviceContext &ctx_;
C
chengduoZH 已提交
277 278 279
  Functor func_;
};

Y
Yu Yang 已提交
280 281
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
282 283 284 285
  const T *x_;
  const T *y_;
  const T *out_;
  const T *dout_;
Y
Yu Yang 已提交
286 287 288 289 290 291

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_ != nullptr) {
C
chengduoZH 已提交
292
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
Y
Yu Yang 已提交
293 294 295 296 297
    }
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
298 299
  T *dx_;
  T *dy_;
Y
Yu Yang 已提交
300 301 302
};

// CPU gradient for the (h, w) x (w) broadcast case: X, Out and dOut are
// row-major h-by-w buffers, Y has w entries.
//   dx[i * w + j] = dx_op(x, y[j], out, dout)            (if dx != nullptr)
//   dy[j]         = sum over i of dy_op(x, y[j], out, dout) (if dy != nullptr)
// dy[j] is overwritten on the first row and accumulated afterwards, so it
// does not need to be zeroed by the caller.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T *dx, T *dy) {
  const bool want_dx = (dx != nullptr);
  const bool want_dy = (dy != nullptr);

  for (int row = 0; row < h; ++row) {
    const int row_base = row * w;
    for (int col = 0; col < w; ++col) {
      const int pos = row_base + col;
      if (want_dx) {
        dx[pos] = dx_op(x[pos], y[col], out[pos], dout[pos]);
      }
      if (!want_dy) {
        continue;
      }
      // Operands are re-read here on purpose rather than cached above.
      T contribution = dy_op(x[pos], y[col], out[pos], dout[pos]);
      if (row == 0) {
        dy[col] = contribution;
      } else {
        dy[col] += contribution;
      }
    }
  }
}
323

D
dzhwinter 已提交
324
#ifdef __NVCC__
Y
Yu Yang 已提交
325 326
// CUDA gradient kernel for the (h, w) x (w) broadcast case. Block blockIdx.x
// handles column j of the h-by-w buffers; its threads walk the rows in steps
// of ELEMWISE_MAX_BLOCK_DIM.
// NOTE(review): the loop body runs once before `row < h` is tested, so the
// kernel appears to assume it is launched with threadIdx.x < h - confirm
// against the launch site.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const int col = blockIdx.x;
  const int tid = threadIdx.x;
  T partial(0);  // this thread's share of dy[col]

  int row = tid;
  do {
    const int pos = row * w + col;
    if (dx) {
      dx[pos] = dx_op(x[pos], y[col], out[pos], dout[pos]);
    }
    if (dy) {
      partial += dy_op(x[pos], y[col], out[pos], dout[pos]);
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dy) {
    // Number of participating threads passed to reduceSum: min(h, max block).
    const int active = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    partial = paddle::platform::reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[col] = partial;
    }
  }
}

354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
#define BLOCK_X 32
#define BLOCK_Y 32

// Gradient kernel for the (h, w) x (w) broadcast case using a 2-D thread
// block. The original authors' rationale: a 2-D block is expected to be
// faster because it exposes more parallelism and coalesces memory accesses.
//
// threadIdx.x selects a column m of the h-by-w buffers, threadIdx.y selects
// the starting row. Per column:
//   dx[n * w + m] = dx_op(...)                     (if dx != nullptr)
//   dy[m]         = sum over rows n of dy_op(...)  (if dy != nullptr)
// NOTE(review): the indexing sdata[threadIdx.y][threadIdx.x] and its
// transpose assume blockDim == (BLOCK_X, BLOCK_Y) - confirm at the launch
// site.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  // Per-block scratch tile holding each thread's partial dy sum. The extra
  // column is presumably padding against shared-memory bank conflicts.
  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];

  // NOTE(review): this outer `val` is never read; the `val` declared inside
  // the loop shadows it.
  T val(0);
  size_t width_stride = gridDim.x * blockDim.x;
  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  // w and h rounded up to the next multiple of BLOCK_X / BLOCK_Y (the bit
  // masks rely on both being powers of two), so that every thread of a block
  // executes the same number of loop iterations and reaches the barriers.
  size_t full_width =
      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
  size_t full_height =
      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);

  for (int m = idx; m < full_width; m += width_stride) {
    sdata[threadIdx.y][threadIdx.x] = 0;
    // Each thread accumulates rows threadIdx.y, threadIdx.y + BLOCK_Y, ...
    for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
      int x_offset = n * w + m;
      // Padding positions (m >= w or n >= h) are skipped.
      if (dx && m < w && n < h) {
        dx[x_offset] = dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
      }
      if (dy) {
        if (m < w && n < h) {
          T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
          sdata[threadIdx.y][threadIdx.x] += val;
        }
        __syncthreads();
      }
    }
    if (dy) {
      // Transposed read: the partial sums of one column now sit along
      // threadIdx.x, so they can be combined with the XOR shuffles below.
      T my_val = sdata[threadIdx.x][threadIdx.y];
      for (int i = warpSize >> 1; i > 0; i >>= 1)
        my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
      // Barrier before sdata row 0 is overwritten with the reduced sums.
      __syncthreads();
      if ((threadIdx.x == 0)) {
        sdata[0][threadIdx.y] = my_val;
      }
      __syncthreads();
      // Row 0 of the block writes the final per-column sums.
      if (threadIdx.y == 0 && m < w) {
        dy[m] = sdata[0][threadIdx.x];
      }
    }
  }
}

Y
Yu Yang 已提交
404
template <typename T, typename DX_OP, typename DY_OP>
405 406
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
                                       const T *y, const T *out, const T *dout,
Y
Yu Yang 已提交
407
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
408
                                       T *dx, T *dy) {
409 410 411 412
  // suppose perfoemance improves with h increased.
  dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
  int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
  FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
C
chengduoZH 已提交
413
      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
Y
Yu Yang 已提交
414 415 416 417 418
}

#endif

// CPU gradient for the (pre, n, post) x (n) broadcast case: X, Out and dOut
// are row-major pre-by-n-by-post buffers, Y has n entries.
//   dx[offset] = dx_op(x, y[j], out, dout)                   (if dx != nullptr)
//   dy[j]      = sum over i and k of dy_op(x, y[j], out, dout) (if dy != nullptr)
// dy[j] is overwritten at (i == 0, k == 0) and accumulated afterwards, so it
// does not need to be zeroed by the caller.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const bool want_dx = (dx != nullptr);
  const bool want_dy = (dy != nullptr);

  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      for (int inner = 0; inner < post; ++inner) {
        const int pos = (outer * n + mid) * post + inner;
        if (want_dx) {
          dx[pos] = dx_op(x[pos], y[mid], out[pos], dout[pos]);
        }
        if (!want_dy) {
          continue;
        }
        // Operands are re-read here on purpose rather than cached above.
        T contribution = dy_op(x[pos], y[mid], out[pos], dout[pos]);
        const bool first_for_mid = (outer == 0 && inner == 0);
        if (first_for_mid) {
          dy[mid] = contribution;
        } else {
          dy[mid] += contribution;
        }
      }
    }
  }
}

#ifdef __NVCC__
// CUDA gradient kernel for the (pre, n, post) x (n) broadcast case. Block
// blockIdx.x handles entry j of Y; its threads walk the pre * post positions
// that share this j in steps of ELEMWISE_MAX_BLOCK_DIM.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const int tid = threadIdx.x;
  const int j = blockIdx.x;

  T partial(0);  // this thread's share of dy[j]

  for (int flat = tid;; flat += ELEMWISE_MAX_BLOCK_DIM) {
    // `flat` enumerates the (i, k) pairs: i = flat / post, k = flat % post.
    const int i = flat / post;
    const int k = flat % post;
    if (i >= pre) {
      break;
    }

    const int pos = i * n * post + j * post + k;

    if (dx != nullptr) {
      dx[pos] = dx_op(x[pos], y[j], out[pos], dout[pos]);
    }
    if (dy != nullptr) {
      partial += dy_op(x[pos], y[j], out[pos], dout[pos]);
    }
  }

  if (dy) {
    // Number of participating threads passed to reduceSum:
    // min(pre * post, max block).
    const int total = pre * post;
    const int active =
        total > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : total;
    partial = paddle::platform::reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[j] = partial;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
483 484
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
                                       const T *y, const T *out, const T *dout,
Y
Yu Yang 已提交
485
                                       int pre, int n, int post, DX_OP dx_op,
486
                                       DY_OP dy_op, T *dx, T *dy) {
Y
Yu Yang 已提交
487 488
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int gird_size = n;
C
chengduoZH 已提交
489 490
  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
Y
Yu Yang 已提交
491 492 493 494
}

#endif

495 496
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
497 498 499 500 501
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
502
  size_t N = static_cast<size_t>(framework::product(x_dim));
D
dzhwinter 已提交
503
#if !defined(_WIN32)
504 505
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
D
dzhwinter 已提交
506 507 508 509
#else
  platform::ForRange<DeviceContext> for_range(
      ctx.device_context<DeviceContext>(), N);
#endif  // !_WIN32
510 511 512 513 514 515 516 517
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
518 519 520 521 522
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
  if (post == 1) {
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast2CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
          dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  }
}

Y
Yu Yang 已提交
565
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
566 567 568 569 570
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                         const framework::Tensor &x, const framework::Tensor &y,
                         const framework::Tensor &out,
                         const framework::Tensor &dout, int axis,
                         framework::Tensor *dx, framework::Tensor *dy,
Y
Yu Yang 已提交
571
                         DX_OP dx_op, DY_OP dy_op) {
572 573
  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
Y
Yu Yang 已提交
574
  if (x.dims() == y.dims()) {
575 576
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
Y
Yu Yang 已提交
577
  } else {  // Y is a scalar
578 579 580 581 582 583 584 585 586 587
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
588 589 590 591 592 593
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
                                 const framework::Tensor &x,
                                 const framework::Tensor &y,
                                 const framework::Tensor &out,
                                 const framework::Tensor &dout, int axis,
                                 framework::Tensor *dx, framework::Tensor *dy,
594 595
                                 DX_OP dx_op, DY_OP dy_op) {
  if (dy == nullptr) {
596
    const framework::DDim &dx_dims = dout.dims();
597 598 599 600 601
    auto dy_dims = dx_dims;
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    if (dout.dims() == dy->dims()) {
602 603
      const framework::DDim &dx_dims = dout.dims();
      const framework::DDim &dy_dims = dy->dims();
604 605 606 607
      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    } else {  // Y is a scalar
      auto dx_dims = dout.dims();
608
      const framework::DDim &dy_dims = dy->dims();
609 610
      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
Y
Yu Yang 已提交
611 612
    }
  }
613
}
Y
Yu Yang 已提交
614

615
// Deprecated
Q
QI JUN 已提交
616
template <typename DeviceContext, typename T, typename functor,
F
fengjiayi 已提交
617
          typename broadcastfunctor, typename broadcast2functor>
618 619 620 621 622 623 624
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
                            const framework::Tensor *x,
                            const framework::Tensor *y,
                            const framework::Tensor *out,
                            const framework::Tensor *dout, int axis,
                            framework::Tensor *dx, framework::Tensor *dy) {
  auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
643
  trim_trailing_singular_dims(y_dims);
644
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
645 646

  int pre, n, post;
647
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
648 649 650 651 652 653 654 655 656 657 658

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
F
fengjiayi 已提交
659

660 661
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
D
dzhwinter 已提交
662

663 664 665 666
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                          const framework::Tensor *x,
                          const framework::Tensor *y, int axis, Functor func,
                          framework::Tensor *z) {
667
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
C
chengduoZH 已提交
668
      x, y, z, ctx.template device_context<DeviceContext>(), func);
F
fengjiayi 已提交
669
  auto x_dims = x->dims();
670 671
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
F
fengjiayi 已提交
672
                    "Rank of first input must >= rank of second input.");
673
  if (x_dims == y_dims_untrimed) {
F
fengjiayi 已提交
674 675 676 677
    functor.Run();
    return;
  }

678
  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
F
fengjiayi 已提交
679 680
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
681
  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
682
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
F
fengjiayi 已提交
683 684

  int pre, n, post;
685
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
F
fengjiayi 已提交
686 687 688 689 690 691 692 693 694
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}

695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003
// FusedElemwiseAndAct
// --- forward
//
// Per-index functor for the fused op when X and Y have the same shape.
// With KeepIntermediateOut the intermediate result is stored in
// intermediate_out_[i] and reused to produce out_[i]; otherwise out_[i] is
// computed directly with GetOut. The data-member order below is kept as in
// the original declaration in case callers brace-initialize the struct.
template <typename T, typename CompoundFunctor, bool KeepIntermediateOut>
struct FusedElemwiseAndActNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
    T y_val = y_[i];
    T x_val = x_[i];
    if (!KeepIntermediateOut) {
      out_[i] = compound_functor_.GetOut(x_val, y_val);
      return;
    }
    T intermediate = compound_functor_.GetIntermediateOut(x_val, y_val);
    intermediate_out_[i] = intermediate;
    out_[i] = compound_functor_.GetOutUseIntermediateOut(x_val, intermediate);
  }

  const T *x_;
  const T *y_;
  CompoundFunctor compound_functor_;
  T *out_;
  T *intermediate_out_;
};

// FusedElemwiseAndActBroadcast1:
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2,
// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20)
//
// BcastY selects which operand is the short (w-element) one: Y when true,
// X when false. With KeepIntermediateOut the intermediate result is stored
// at index j when it has the broadcast operand's shape (BcastY and not
// SameShapeOfIntermediateOutAndOut) and at the full offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y,
                                             CompoundFunctor compound_functor,
                                             int h, int w, T *out,
                                             T *intermediate_out) {
  // True when the intermediate output has only w entries.
  const bool intermediate_is_short =
      !SameShapeOfIntermediateOutAndOut && BcastY;

  for (int row = 0; row < h; ++row) {
    for (int col = 0; col < w; ++col) {
      const int offset = row * w + col;

      T y_val = y[BcastY ? col : offset];
      T x_val = x[BcastY ? offset : col];

      if (!KeepIntermediateOut) {
        out[offset] = compound_functor.GetOut(x_val, y_val);
        continue;
      }

      T intermediate = compound_functor.GetIntermediateOut(x_val, y_val);
      const int64_t intermediate_out_offset =
          intermediate_is_short ? col : offset;
      intermediate_out[intermediate_out_offset] = intermediate;
      out[offset] =
          compound_functor.GetOutUseIntermediateOut(x_val, intermediate);
    }
  }
}

// FusedElemwiseAndActBroadcast2
// In this case, X and Y can be reshaped to a 3-D (pre, n, post) layout.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1,
// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1)
// pre = 2, n = 12, post = 5
//
// BcastY selects which operand is the short (n-element) one: Y when true,
// X when false. With KeepIntermediateOut the intermediate result is stored
// at index j when it has the broadcast operand's shape (BcastY and not
// SameShapeOfIntermediateOutAndOut) and at the full offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre,
                                             int n, int post,
                                             CompoundFunctor compound_functor,
                                             T *out, T *intermediate_out) {
  // True when the intermediate output has only n entries.
  const bool intermediate_is_short =
      !SameShapeOfIntermediateOutAndOut && BcastY;

  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      for (int inner = 0; inner < post; ++inner) {
        const int offset = (outer * n + mid) * post + inner;

        T y_val = y[BcastY ? mid : offset];
        T x_val = x[BcastY ? offset : mid];

        if (!KeepIntermediateOut) {
          out[offset] = compound_functor.GetOut(x_val, y_val);
          continue;
        }

        T intermediate = compound_functor.GetIntermediateOut(x_val, y_val);
        const int64_t intermediate_out_offset =
            intermediate_is_short ? mid : offset;
        intermediate_out[intermediate_out_offset] = intermediate;
        out[offset] =
            compound_functor.GetOutUseIntermediateOut(x_val, intermediate);
      }
    }
  }
}

#ifdef __NVCC__
// CUDA forward kernel for the matrix-shaped broadcast case (h rows, w
// columns). A block works on the column given by blockIdx.x; each thread
// starts at row threadIdx.x and advances ELEMWISE_MAX_BLOCK_DIM rows per
// iteration until it runs past h.
//
// When BcastY is true, `x` is indexed by the output offset and `y` by the
// column; when BcastY is false the roles are swapped. With
// KeepIntermediateOut the intermediate value is stored at the column index
// if Y is broadcast and SameShapeOfIntermediateOutAndOut is false, and at the
// output offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
    const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
    T *out, T *intermediate_out) {
  const int col = blockIdx.x;
  // True when the intermediate result only has one slot per column.
  const bool inter_indexed_by_col =
      BcastY && !SameShapeOfIntermediateOutAndOut;

  for (int row = threadIdx.x; row < h; row += ELEMWISE_MAX_BLOCK_DIM) {
    const int pos = row * w + col;

    T rhs = BcastY ? y[col] : y[pos];
    T lhs = BcastY ? x[pos] : x[col];

    if (!KeepIntermediateOut) {
      out[pos] = compound_functor.GetOut(lhs, rhs);
      continue;
    }

    T inter_val = compound_functor.GetIntermediateOut(lhs, rhs);
    const int64_t inter_pos = inter_indexed_by_col ? col : pos;

    intermediate_out[inter_pos] = inter_val;
    out[pos] = compound_functor.GetOutUseIntermediateOut(lhs, inter_val);
  }
}

// Launches FusedElemwiseAndActBroadcast1CUDAKernel on `stream` with one block
// per column (w blocks) and min(h, ELEMWISE_MAX_BLOCK_DIM) threads per block.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CUDA(cudaStream_t stream, const T *x,
                                              const T *y,
                                              CompoundFunctor compound_functor,
                                              int h, int w, T *out,
                                              T *intermediate_out) {
  const int threads_per_block =
      h < ELEMWISE_MAX_BLOCK_DIM ? h : ELEMWISE_MAX_BLOCK_DIM;
  const int num_blocks = w;

  FusedElemwiseAndActBroadcast1CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<num_blocks, threads_per_block, 0,
                                          stream>>>(
      x, y, h, w, compound_functor, out, intermediate_out);
}

// CUDA forward kernel for the (pre, n, post) broadcast case. A block works on
// the middle index given by blockIdx.x. Each thread walks a linear index that
// starts at threadIdx.x and advances by ELEMWISE_MAX_BLOCK_DIM; the linear
// index is split into (outer, inner) = (index / post, index % post) and the
// walk stops once outer reaches pre.
//
// When BcastY is true, `x` is indexed by the output offset and `y` by the
// middle index; when BcastY is false the roles are swapped. With
// KeepIntermediateOut the intermediate value is stored at the middle index if
// Y is broadcast and SameShapeOfIntermediateOutAndOut is false, and at the
// output offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel(
    const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n,
    int post, T *out, T *intermediate_out) {
  const int mid = blockIdx.x;
  // True when the intermediate result only has one slot per middle index.
  const bool inter_indexed_by_mid =
      BcastY && !SameShapeOfIntermediateOutAndOut;

  for (int linear = threadIdx.x; linear / post < pre;
       linear += ELEMWISE_MAX_BLOCK_DIM) {
    const int outer = linear / post;
    const int inner = linear % post;
    const int pos = outer * n * post + mid * post + inner;

    T rhs = BcastY ? y[mid] : y[pos];
    T lhs = BcastY ? x[pos] : x[mid];

    if (!KeepIntermediateOut) {
      out[pos] = compound_functor.GetOut(lhs, rhs);
      continue;
    }

    T inter_val = compound_functor.GetIntermediateOut(lhs, rhs);
    const int64_t inter_pos = inter_indexed_by_mid ? mid : pos;

    intermediate_out[inter_pos] = inter_val;
    out[pos] = compound_functor.GetOutUseIntermediateOut(lhs, inter_val);
  }
}

// Launches FusedElemwiseAndActBroadcast2CUDAKernel on `stream` with one block
// per middle index (n blocks) and min(pre * post, ELEMWISE_MAX_BLOCK_DIM)
// threads per block.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CUDA(cudaStream_t stream, const T *x,
                                              const T *y, int pre, int n,
                                              int post,
                                              CompoundFunctor compound_functor,
                                              T *out, T *intermediate_out) {
  const int work_per_block = pre * post;
  const int threads_per_block = work_per_block < ELEMWISE_MAX_BLOCK_DIM
                                    ? work_per_block
                                    : ELEMWISE_MAX_BLOCK_DIM;
  const int num_blocks = n;

  FusedElemwiseAndActBroadcast2CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<num_blocks, threads_per_block, 0,
                                          stream>>>(
      x, y, compound_functor, pre, n, post, out, intermediate_out);
}

#endif

// Forward pass for the case where no broadcasting is involved: applies a
// FusedElemwiseAndActNoBroadcast functor to every index in
// [0, product(x_dim)) through platform::ForRange on the context's device.
// `out` is allocated on ctx.GetPlace(); `intermediate_out` may be nullptr, in
// which case a null pointer is handed to the functor.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut>
void FusedElemwiseAndActComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::Tensor &x, const framework::Tensor &y,
    CompoundFunctor compound_functor, framework::Tensor *out,
    framework::Tensor *intermediate_out) {
  const size_t numel = static_cast<size_t>(framework::product(x_dim));

  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), numel);

  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  T *out_data = out->mutable_data<T>(ctx.GetPlace());
  T *inter_data = nullptr;
  if (intermediate_out != nullptr) {
    inter_data = intermediate_out->mutable_data<T>(ctx.GetPlace());
  }

  for_range(
      FusedElemwiseAndActNoBroadcast<T, CompoundFunctor, KeepIntermediateOut>{
          x_data, y_data, compound_functor, out_data, inter_data});
}

// Forward pass for the broadcasting case. Resolves `axis` (-1 means "align Y
// with the trailing dimensions of X"), trims trailing singular dimensions of
// Y, splits X's shape into (pre, n, post) via get_mid_dims and dispatches to
// the Broadcast1 kernels when post == 1 or to the Broadcast2 kernels
// otherwise, choosing the CUDA or CPU flavour from ctx.GetPlace().
//
// On a GPU place the CUDA launchers are only compiled under __NVCC__; without
// it this function does nothing for GPU places. `intermediate_out` may be
// nullptr.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool BcastY, bool KeepIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, CompoundFunctor compound_functor, int axis,
    framework::Tensor *out, framework::Tensor *intermediate_out) {
  if (axis == -1) {
    axis = x_dim.size() - y_dim_untrimed.size();
  }
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  if (y_dim.size() == 0) {
    axis = x_dim.size();
  }

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  // post == 1 means the operands can be treated as (pre x n) matrices.
  const bool matrix_case = (post == 1);

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    auto stream = ctx.template device_context<DeviceContext>().stream();
    const T *x_data = x.data<T>();
    const T *y_data = y.data<T>();
    T *out_data = out->mutable_data<T>(ctx.GetPlace());
    T *inter_data = intermediate_out == nullptr
                        ? nullptr
                        : intermediate_out->mutable_data<T>(ctx.GetPlace());
    if (matrix_case) {
      FusedElemwiseAndActBroadcast1CUDA<T, CompoundFunctor, BcastY,
                                        KeepIntermediateOut,
                                        SameShapeOfIntermediateOutAndOut>(
          stream, x_data, y_data, compound_functor, pre, n, out_data,
          inter_data);
    } else {
      FusedElemwiseAndActBroadcast2CUDA<T, CompoundFunctor, BcastY,
                                        KeepIntermediateOut,
                                        SameShapeOfIntermediateOutAndOut>(
          stream, x_data, y_data, pre, n, post, compound_functor, out_data,
          inter_data);
    }
#endif
    return;
  }

  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  T *out_data = out->mutable_data<T>(ctx.GetPlace());
  T *inter_data = intermediate_out == nullptr
                      ? nullptr
                      : intermediate_out->mutable_data<T>(ctx.GetPlace());
  if (matrix_case) {
    FusedElemwiseAndActBroadcast1CPU<T, CompoundFunctor, BcastY,
                                     KeepIntermediateOut,
                                     SameShapeOfIntermediateOutAndOut>(
        x_data, y_data, compound_functor, pre, n, out_data, inter_data);
  } else {
    FusedElemwiseAndActBroadcast2CPU<T, CompoundFunctor, BcastY,
                                     KeepIntermediateOut,
                                     SameShapeOfIntermediateOutAndOut>(
        x_data, y_data, pre, n, post, compound_functor, out_data, inter_data);
  }
}

// --- backward
C
chengduo 已提交
1004 1005
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut>
1006 1007 1008
struct FusedElemwiseAndActGradNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
C
chengduo 已提交
1009 1010 1011 1012
      dx_[i] = UseIntermediateOut
                   ? dx_op_.UseIntermediateOut(
                         x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
                   : dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
1013 1014
    }
    if (dy_ != nullptr) {
C
chengduo 已提交
1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025
      dy_[i] = UseIntermediateOut
                   ? dy_op_.UseIntermediateOut(
                         x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
                   : dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dintermediate_ != nullptr) {
      dintermediate_[i] =
          UseIntermediateOut
              ? dintermediate_op_.UseIntermediateOut(
                    x_[i], intermediate_out_[i], out_[i], dout_[i])
              : dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
1026 1027 1028 1029 1030 1031 1032 1033 1034 1035
    }
  }

  const T *x_;
  const T *y_;
  const T *intermediate_out_;
  const T *out_;
  const T *dout_;
  DX_OP dx_op_;
  DY_OP dy_op_;
C
chengduo 已提交
1036
  DIntermediate_OP dintermediate_op_;
1037 1038
  T *dx_;
  T *dy_;
C
chengduo 已提交
1039
  T *dintermediate_;
1040 1041 1042
};

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
C
chengduo 已提交
1043
          typename DIntermediate_OP, bool UseIntermediateOut>
1044 1045 1046 1047 1048
void FusedElemwiseAndActGradComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
C
chengduo 已提交
1049 1050 1051
    framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
1052 1053 1054 1055
  size_t N = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
  for_range(
C
chengduo 已提交
1056 1057
      FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, DIntermediate_OP,
                                         UseIntermediateOut>{
1058 1059
          x->data<T>(), y->data<T>(),
          intermediate_out ? intermediate_out->data<T>() : nullptr,
C
chengduo 已提交
1060
          out->data<T>(), dout->data<T>(), dx_op, dy_op, dintermediate_op,
1061
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
C
chengduo 已提交
1062 1063 1064
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
          dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
                                                   ctx.GetPlace())});
1065 1066
}

// CPU backward kernel for the matrix-shaped broadcast case (h rows, w
// columns; out and dout have h * w elements).
//
// When BcastY is true, `x` is indexed by the output offset and `y` by the
// column, so dy is the sum over rows and dx is written element-wise. When
// BcastY is false the roles are swapped. The intermediate value is indexed by
// the column only when Y is broadcast and SameShapeOfIntermediateOutAndOut is
// false; in every other case it is indexed by the output offset.
//
// `dx`, `dy` and `d_intermediate` may each be nullptr, in which case that
// gradient is skipped. `intermediate_out` is only dereferenced when
// UseIntermediateOut is true.
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  // The intermediate gradient is a reduction over rows only when its index
  // is the column alone. Otherwise every offset is visited exactly once and
  // must be assigned; accumulating would read memory that was never written.
  const bool reduce_intermediate =
      BcastY && !SameShapeOfIntermediateOutAndOut;

  int64_t tmp_out_idx, x_idx, y_idx;
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      int offset = i * w + j;

      tmp_out_idx = BcastY ? j : offset;
      y_idx = BcastY ? j : offset;
      x_idx = BcastY ? offset : j;

      if (SameShapeOfIntermediateOutAndOut) {
        tmp_out_idx = offset;
      }

      if (dx != nullptr) {
        T tmp = UseIntermediateOut
                    ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                               intermediate_out[tmp_out_idx],
                                               out[offset], dout[offset])
                    : dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                      dout[offset]);

        if (BcastY) {
          dx[x_idx] = tmp;
        } else {
          // X is the broadcast operand: sum over rows, starting at row 0.
          if (i == 0) {
            dx[x_idx] = tmp;
          } else {
            dx[x_idx] += tmp;
          }
        }
      }
      if (dy != nullptr) {
        T tmp = UseIntermediateOut
                    ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                               intermediate_out[tmp_out_idx],
                                               out[offset], dout[offset])
                    : dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                      dout[offset]);
        if (BcastY) {
          // Y is the broadcast operand: sum over rows, starting at row 0.
          if (i == 0) {
            dy[y_idx] = tmp;
          } else {
            dy[y_idx] += tmp;
          }
        } else {
          dy[y_idx] = tmp;
        }
      }
      if (d_intermediate != nullptr) {
        // dout shares the layout of out, so it is indexed by `offset` in
        // both branches.
        T tmp = UseIntermediateOut
                    ? dintermediate_op.UseIntermediateOut(
                          x[x_idx], intermediate_out[tmp_out_idx], out[offset],
                          dout[offset])
                    : dintermediate_op.Recompute(x[x_idx], y[y_idx],
                                                 out[offset], dout[offset]);
        if (reduce_intermediate && i != 0) {
          d_intermediate[tmp_out_idx] += tmp;
        } else {
          d_intermediate[tmp_out_idx] = tmp;
        }
      }
    }
  }
}

// CPU backward kernel for the (pre, n, post) broadcast case; out and dout
// have pre * n * post elements in row-major (pre, n, post) order.
//
// When BcastY is true, `x` is indexed by the output offset and `y` by the
// middle index j, so dy[j] is the sum over all (i, k) and dx is written
// element-wise. When BcastY is false the roles are swapped. The intermediate
// value is indexed by j only when Y is broadcast and
// SameShapeOfIntermediateOutAndOut is false; in every other case it is
// indexed by the output offset.
//
// `dx`, `dy` and `d_intermediate` may each be nullptr, in which case that
// gradient is skipped. `intermediate_out` is only dereferenced when
// UseIntermediateOut is true.
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
  // The intermediate gradient is a reduction over (i, k) only when its index
  // is j alone. Otherwise every offset is visited exactly once and must be
  // assigned; accumulating would read memory that was never written.
  const bool reduce_intermediate =
      BcastY && !SameShapeOfIntermediateOutAndOut;

  int64_t tmp_out_idx, x_idx, y_idx;
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        int offset = i * n * post + j * post + k;
        // First visit of middle index j: reductions start here.
        const bool first_for_j = (i == 0 && k == 0);

        tmp_out_idx = BcastY ? j : offset;
        y_idx = BcastY ? j : offset;
        x_idx = BcastY ? offset : j;

        if (SameShapeOfIntermediateOutAndOut) {
          tmp_out_idx = offset;
        }

        if (dx != nullptr) {
          T tmp = UseIntermediateOut
                      ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                                 intermediate_out[tmp_out_idx],
                                                 out[offset], dout[offset])
                      : dx_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                        dout[offset]);

          if (BcastY) {
            dx[x_idx] = tmp;
          } else {
            if (first_for_j) {
              dx[x_idx] = tmp;
            } else {
              dx[x_idx] += tmp;
            }
          }
        }
        if (dy != nullptr) {
          T tmp = UseIntermediateOut
                      ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                                 intermediate_out[tmp_out_idx],
                                                 out[offset], dout[offset])
                      : dy_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                        dout[offset]);
          if (BcastY) {
            if (first_for_j) {
              dy[y_idx] = tmp;
            } else {
              dy[y_idx] += tmp;
            }
          } else {
            dy[y_idx] = tmp;
          }
        }
        if (d_intermediate != nullptr) {
          // dout shares the layout of out, so it is indexed by `offset` in
          // both branches.
          T tmp = UseIntermediateOut
                      ? dintermediate_op.UseIntermediateOut(
                            x[x_idx], intermediate_out[tmp_out_idx],
                            out[offset], dout[offset])
                      : dintermediate_op.Recompute(x[x_idx], y[y_idx],
                                                   out[offset], dout[offset]);
          // Use the same "first visit" rule as dx/dy: resetting on i == 0
          // alone would drop the contributions of earlier k values.
          if (reduce_intermediate && !first_for_j) {
            d_intermediate[tmp_out_idx] += tmp;
          } else {
            d_intermediate[tmp_out_idx] = tmp;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
C
chengduo 已提交
1222 1223 1224
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
1225 1226
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
C
chengduo 已提交
1227 1228
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
1229 1230 1231
  int j = blockIdx.x;
  int i = threadIdx.x;
  int tid = threadIdx.x;
C
chengduo 已提交
1232
  T val(0), inter_val(0);
1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246
  int64_t tmp_out_idx, x_idx, y_idx;

  do {
    int offset = i * w + j;

    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
C
chengduo 已提交
1247 1248 1249 1250 1251 1252
      T tmp =
          UseIntermediateOut
              ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
1253 1254 1255 1256 1257 1258 1259 1260

      if (BcastY) {
        dx[x_idx] = tmp;
      } else {
        val += tmp;
      }
    }
    if (dy != nullptr) {
C
chengduo 已提交
1261 1262 1263 1264 1265 1266
      T tmp =
          UseIntermediateOut
              ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
1267 1268 1269 1270 1271 1272
      if (BcastY) {
        val += tmp;
      } else {
        dy[y_idx] = tmp;
      }
    }
C
chengduo 已提交
1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285
    if (d_intermediate != nullptr) {
      T tmp = UseIntermediateOut
                  ? dintermediate_op.UseIntermediateOut(
                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
                        dout[offset])
                  : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                               dout[offset]);
      if (SameShapeOfIntermediateOutAndOut) {
        d_intermediate[tmp_out_idx] = tmp;
      } else {
        inter_val += tmp;
      }
    }
1286 1287 1288 1289

    i += ELEMWISE_MAX_BLOCK_DIM;
  } while (i < h);

C
chengduo 已提交
1290
  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305
  if (BcastY) {
    if (dy) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
C
chengduo 已提交
1306 1307 1308 1309 1310 1311 1312 1313
  if (!SameShapeOfIntermediateOutAndOut) {
    if (d_intermediate) {
      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
      if (threadIdx.x == 0) {
        d_intermediate[j] = inter_val;
      }
    }
  }
1314 1315
}

C
chengduo 已提交
1316 1317 1318 1319 1320 1321 1322
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(
    cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
    const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
1323 1324 1325
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  int gird_size = w;
  FusedElemwiseAndActGradBroadcast1CUDAKernel<
C
chengduo 已提交
1326
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
1327
      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
C
chengduo 已提交
1328 1329
      x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
      dx, dy, d_intermediate);
1330 1331
}

C
chengduo 已提交
1332 1333 1334
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
1335 1336
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
C
chengduo 已提交
1337 1338
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
1339 1340 1341
  int tid = threadIdx.x;
  int j = blockIdx.x;

C
chengduo 已提交
1342
  T val(0), inter_val(0);
1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360
  int ttid = tid;
  int64_t tmp_out_idx, x_idx, y_idx;
  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int offset = i * n * post + j * post + k;

    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
C
chengduo 已提交
1361 1362 1363 1364 1365 1366
      T tmp =
          UseIntermediateOut
              ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
1367 1368 1369 1370 1371 1372 1373 1374

      if (BcastY) {
        dx[x_idx] = tmp;
      } else {
        val += tmp;
      }
    }
    if (dy != nullptr) {
C
chengduo 已提交
1375 1376 1377 1378 1379 1380
      T tmp =
          UseIntermediateOut
              ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx],
                                         intermediate_out[tmp_out_idx],
                                         out[offset], dout[offset])
              : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]);
1381 1382 1383 1384 1385 1386
      if (BcastY) {
        val += tmp;
      } else {
        dy[y_idx] = tmp;
      }
    }
C
chengduo 已提交
1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399
    if (d_intermediate != nullptr) {
      T tmp = UseIntermediateOut
                  ? dintermediate_op.UseIntermediateOut(
                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
                        dout[offset])
                  : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset],
                                               dout[offset]);
      if (SameShapeOfIntermediateOutAndOut) {
        d_intermediate[tmp_out_idx] = tmp;
      } else {
        inter_val += tmp;
      }
    }
1400 1401 1402
    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

C
chengduo 已提交
1403 1404
  int h = pre * post;
  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419
  if (BcastY) {
    if (dy) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
C
chengduo 已提交
1420 1421 1422 1423 1424 1425 1426 1427
  if (!SameShapeOfIntermediateOutAndOut) {
    if (d_intermediate) {
      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
      if (threadIdx.x == 0) {
        d_intermediate[j] = inter_val;
      }
    }
  }
1428 1429
}

C
chengduo 已提交
1430 1431 1432
template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
1433 1434 1435
static void FusedElemwiseAndActGradBroadcast2CUDA(
    cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
    const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
C
chengduo 已提交
1436 1437
    DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
    T *dintermediate) {
1438 1439 1440
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int gird_size = n;
  FusedElemwiseAndActGradBroadcast2CUDAKernel<
C
chengduo 已提交
1441
      T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
1442
      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
C
chengduo 已提交
1443 1444
      x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op,
      dintermediate_op, dx, dy, dintermediate);
1445 1446 1447 1448
}
#endif

// Backward pass for the broadcasting case. Resolves `axis` (-1 means "align Y
// with the trailing dimensions of X"), trims trailing singular dimensions of
// Y, splits X's shape into (pre, n, post) via get_mid_dims and dispatches to
// the GradBroadcast1 kernels when post == 1 or to the GradBroadcast2 kernels
// otherwise, choosing the CUDA or CPU flavour from ctx.GetPlace().
//
// `intermediate_out`, `dx`, `dy` and `dintermediate` may each be nullptr; a
// null pointer is then forwarded to the kernel. Non-null gradient tensors are
// allocated on ctx.GetPlace(). On a GPU place the CUDA launchers are only
// compiled under __NVCC__; without it nothing is done for GPU places.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          typename DIntermediate_OP, bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
    framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
  if (post == 1) {
    // The operands can be treated as (pre x n) matrices.
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
          y->data<T>(),
          intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
          out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
          dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
                                                   ctx.GetPlace()));
#endif
    } else {
      FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
          x->data<T>(), y->data<T>(),
          intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
          out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
          dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
                                                   ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                            UseIntermediateOut, BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
          y->data<T>(),
          intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
          out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
          dintermediate_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
          dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
                                                   ctx.GetPlace()));
#endif
    } else {
      FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
          x->data<T>(), y->data<T>(),
          intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
          out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
          dintermediate_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()),
          dintermediate == nullptr ? nullptr : dintermediate->mutable_data<T>(
                                                   ctx.GetPlace()));
    }
  }
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
C
chengduo 已提交
1527 1528
          typename DIntermediate_OP, bool UseIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
1529 1530 1531 1532
void FusedElemwiseAndActGradComputeEx(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *out,
    const framework::Tensor *intermediate_out, const framework::Tensor *dout,
C
chengduo 已提交
1533 1534 1535
    int axis, framework::Tensor *dx, framework::Tensor *dy,
    framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op,
    DIntermediate_OP dintermediate_op) {
1536 1537 1538 1539 1540 1541
  const framework::DDim &x_dim = x->dims();
  const framework::DDim &y_dim = y->dims();
  if (UseIntermediateOut) {
    PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
  }
  if (x_dim == y_dim) {
C
chengduo 已提交
1542 1543
    FusedElemwiseAndActGradComputeNoBroadcast<
        DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>(
1544
        ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
C
chengduo 已提交
1545
        dintermediate, dx_op, dy_op, dintermediate_op);
1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560
  } else {  // Y is a scalar
    bool bcast_y = x_dim.size() >= y_dim.size();
    if (x_dim.size() == y_dim.size()) {
      for (int i = 0; i < x_dim.size(); ++i) {
        if (x_dim[i] < y_dim[i]) {
          bcast_y = false;
          break;
        }
      }
    }

    // z = f1(x, f2(y))
    // z = f1(f2(x, y))
    if (bcast_y) {  // Y should be broadcast.
      FusedElemwiseAndActGradComputeWithBroadcast<
C
chengduo 已提交
1561 1562 1563 1564
          DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
          true /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
          ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
          dintermediate, dx_op, dy_op, dintermediate_op);
1565 1566
    } else {
      FusedElemwiseAndActGradComputeWithBroadcast<
C
chengduo 已提交
1567 1568 1569 1570
          DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut,
          false /*BcastY*/, SameShapeOfIntermediateOutAndOut>(
          ctx, y_dim, x_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
          dintermediate, dx_op, dy_op, dintermediate_op);
1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584
    }
  }
}

/*
 * Forward-pass dispatcher for the fused elementwise + activation op.
 *
 * Compares the shapes of X and Y and forwards to one of:
 *   - FusedElemwiseAndActComputeNoBroadcast   (shapes are equal)
 *   - FusedElemwiseAndActComputeWithBroadcast (shapes differ), with the
 *     BcastY template flag selecting which operand is broadcast.
 *
 * When KeepIntermediateOut is true, intermediate_out must be non-null; this
 * is checked with PADDLE_ENFORCE.
 */
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
                                  const framework::Tensor &x,
                                  const framework::Tensor &y, int axis,
                                  CompoundFunctor compound_functor,
                                  framework::Tensor *out,
                                  framework::Tensor *intermediate_out) {
  if (KeepIntermediateOut) {
    PADDLE_ENFORCE(intermediate_out,
                   "The save_intermediate_out is opened, "
                   "intermediate_out should not be nullptr.");
  }

  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();

  // Identical shapes: no broadcasting is needed.
  if (x_dim == y_dim) {
    FusedElemwiseAndActComputeNoBroadcast<DeviceContext, T, CompoundFunctor,
                                          KeepIntermediateOut>(
        ctx, x_dim, x, y, compound_functor, out, intermediate_out);
    return;
  }

  // Whether the shape of Y is a continuous subsequence of X.
  // For more information please refer to the op's introduction.
  // Y is treated as the broadcast operand unless it has a higher rank than X
  // or, at equal rank, exceeds X in some dimension.
  bool y_is_broadcast = true;
  if (x_dim.size() < y_dim.size()) {
    y_is_broadcast = false;
  } else if (x_dim.size() == y_dim.size()) {
    for (int d = 0; d < x_dim.size(); ++d) {
      if (x_dim[d] < y_dim[d]) {
        y_is_broadcast = false;
        break;
      }
    }
  }

  // The fused op has one of the forms:
  //   z = f1(x, f2(y))
  //   z = f1(f2(x, y))
  if (!y_is_broadcast) {
    // X is the broadcast operand. In this case:
    //  - for 'f2(y)', the shape of intermediate_out should be equal to the
    //    shape of Out;
    //  - for 'f2(x, y)', the shape of intermediate_out should be equal to
    //    the shape of Out;
    //  - the shape of Out should be equal to the shape of Y.
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, false /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out,
        intermediate_out);
    return;
  }

  // Y is the broadcast operand. In this case:
  //  - for 'f2(y)', the shape of intermediate_out should be equal to the
  //    shape of Y;
  //  - for 'f2(x, y)', the shape of intermediate_out should be equal to the
  //    shape of Out;
  //  - the shape of Out should be equal to the shape of X.
  FusedElemwiseAndActComputeWithBroadcast<
      DeviceContext, T, CompoundFunctor, true /*BcastY*/, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut>(
      ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out,
      intermediate_out);
}
1639 1640
}  // namespace operators
}  // namespace paddle