elementwise_op_function.h 54.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#pragma once
16

17
#include <glog/logging.h>
18
#include <algorithm>
19
#include <vector>
Y
Yi Wang 已提交
20 21 22 23
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
24

C
chengduoZH 已提交
25
#ifdef __NVCC__
26
#include <cuda.h>
C
chengduoZH 已提交
27
#include <thrust/iterator/iterator_adaptor.h>
28
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
29
#include "paddle/fluid/platform/cuda_primitives.h"
Y
Yu Yang 已提交
30
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
C
chengduoZH 已提交
31 32
#endif

Y
Yi Wang 已提交
33
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
34
#include "paddle/fluid/platform/for_range.h"
35 36 37 38 39 40 41 42 43 44

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, they will be reshaped.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5), with axis=-1 (Y is aligned
 *    to the trailing dimensions of X, i.e. axis = rank(X) - rank(Y) = 2)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
50 51 52
// Splits x_dims into three products around the span that y_dims occupies
// when y is aligned to x starting at dimension `axis`:
//   *pre  = prod(x_dims[0 .. axis))
//   *n    = prod(y_dims)                 (must equal x_dims[axis .. axis+|y|))
//   *post = prod(x_dims[axis+|y| .. rank(x)))
// Enforces that the aligned span lies inside x and that every aligned
// dimension of y matches the corresponding dimension of x.
inline void get_mid_dims(const framework::DDim &x_dims,
                         const framework::DDim &y_dims, const int axis,
                         int *pre, int *n, int *post) {
  // Validate up front so a bad axis reports a clear error rather than an
  // out-of-range dimension index below.
  PADDLE_ENFORCE_GE(axis, 0, "Axis should be non-negative.");
  PADDLE_ENFORCE_LE(axis + y_dims.size(), x_dims.size(),
                    "Broadcast dimensions of Y exceed the rank of X.");

  *pre = 1;
  *n = 1;
  *post = 1;
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }

  for (int i = 0; i < y_dims.size(); ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    (*n) *= y_dims[i];
  }

  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
}

71
// Returns `dims` with every trailing size-1 dimension removed, e.g.
// (3, 4, 1, 1) -> (3, 4). If all dimensions are 1 (or dims is empty) the
// result is the rank-0 DDim built from framework::make_dim().
inline framework::DDim trim_trailing_singular_dims(
    const framework::DDim &dims) {
  auto kept = dims.size();
  while (kept != 0 && dims[kept - 1] == 1) {
    --kept;
  }

  if (kept == 0) {
    return framework::DDim(framework::make_dim());
  }

  std::vector<int> leading(kept);
  for (int d = 0; d < kept; ++d) {
    leading[d] = dims[d];
  }
  return framework::make_ddim(leading);
}

Q
QI JUN 已提交
91
template <typename T, typename DeviceContext>
C
chengduoZH 已提交
92
class RowwiseTransformIterator;
93

Q
QI JUN 已提交
94
template <typename T, typename DeviceContext>
C
chengduoZH 已提交
95
class MidWiseTransformIterator;
C
chengduoZH 已提交
96 97

template <typename T>
Q
QI JUN 已提交
98
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
99
 public:
100
  RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
C
chengduoZH 已提交
101

102
  RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
C
chengduoZH 已提交
103
    ++i_;
C
chengduoZH 已提交
104 105 106
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
C
chengduoZH 已提交
107 108 109
    return *this;
  }

110 111
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
112
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
113 114
  }

115 116
  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
117
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
118 119
  }

120
  const T &operator*() { return ptr_[i_]; }
C
chengduoZH 已提交
121

C
chengduoZH 已提交
122
 private:
123
  const T *ptr_;
C
chengduoZH 已提交
124
  int i_;
C
chengduoZH 已提交
125
  int64_t n_;
C
chengduoZH 已提交
126 127 128
};

template <typename T>
Q
QI JUN 已提交
129
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
130
 public:
131
  MidWiseTransformIterator(const T *ptr, int n, int post)
C
chengduoZH 已提交
132 133
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

134
  MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
C
chengduoZH 已提交
135
    ++j_;
C
chengduoZH 已提交
136 137
    if (UNLIKELY(j_ == post_)) {
      ++i_;
C
refine  
chengduoZH 已提交
138
      j_ = 0;
C
chengduoZH 已提交
139 140 141
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
C
chengduoZH 已提交
142
    }
C
chengduoZH 已提交
143 144 145
    return *this;
  }

146 147
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
148
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
149 150
  }

151 152
  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
                      &rhs) const {
C
chengduoZH 已提交
153
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
154 155
  }

156
  const T &operator*() { return ptr_[i_]; }
C
chengduoZH 已提交
157

C
chengduoZH 已提交
158
 private:
159
  const T *ptr_;
C
refine  
chengduoZH 已提交
160
  int64_t i_;
C
chengduoZH 已提交
161 162
  int64_t j_;
  int64_t n_;
C
refine  
chengduoZH 已提交
163
  int64_t post_;
C
chengduoZH 已提交
164 165
};

C
chengduoZH 已提交
166 167
#ifdef __NVCC__
template <typename T>
Q
QI JUN 已提交
168
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
169
    : public thrust::iterator_adaptor<
170
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
C
chengduoZH 已提交
171 172
 public:
  typedef thrust::iterator_adaptor<
173
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
C
chengduoZH 已提交
174
      super_t;
175
  HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
176
      : super_t(x), begin_(x), n_(n) {}
C
chengduoZH 已提交
177 178 179 180
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
181
  const T *begin_;
C
chengduoZH 已提交
182
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
183 184 185 186 187
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
Q
QI JUN 已提交
188
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
189
    : public thrust::iterator_adaptor<
190
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
C
chengduoZH 已提交
191 192
 public:
  typedef thrust::iterator_adaptor<
193
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
C
chengduoZH 已提交
194
      super_t;
195
  HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
196
      : super_t(x), begin_(x), n_(n), post_(post) {}
C
chengduoZH 已提交
197 198 199 200 201
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
202
  const T *begin_;
C
chengduoZH 已提交
203
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
204 205 206 207 208
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

209 210
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
211 212
class TransformFunctor {
 public:
213 214
  TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
                   framework::Tensor *z, const DeviceContext &ctx, Functor func)
C
chengduoZH 已提交
215 216
      : x_(x->data<T>()),
        y_(y->data<T>()),
217
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
218 219 220 221 222
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
Q
QI JUN 已提交
223
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
224
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
225 226 227
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
228 229 230
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
C
chengduoZH 已提交
231 232 233
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
234 235 236
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
C
chengduoZH 已提交
237 238
  }

C
chengduoZH 已提交
239
 private:
240 241 242
  const T *x_;
  const T *y_;
  OutType *z_;
C
chengduoZH 已提交
243
  int64_t nx_;
244
  const DeviceContext &ctx_;
C
chengduoZH 已提交
245 246 247
  Functor func_;
};

248 249
#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
Q
QI JUN 已提交
250
    template <typename DeviceContext, typename T>                              \
251 252 253
    inline void Run(const framework::Tensor *x, const framework::Tensor *y,    \
                    framework::Tensor *z,                                      \
                    const framework::ExecutionContext &ctx) {                  \
254 255 256
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
Q
QI JUN 已提交
257 258 259
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_e);                                                  \
260
    }                                                                          \
Q
QI JUN 已提交
261
    template <typename DeviceContext, typename T>                              \
262 263 264
    inline void RunBroadCast(const framework::Tensor *x,                       \
                             const framework::Tensor *y, framework::Tensor *z, \
                             const framework::ExecutionContext &ctx, int pre,  \
265 266 267 268 269 270 271
                             int n) {                                          \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
Q
QI JUN 已提交
272 273 274
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
275
    }                                                                          \
Q
QI JUN 已提交
276
    template <typename DeviceContext, typename T>                              \
277 278 279 280
    inline void RunBroadCast2(const framework::Tensor *x,                      \
                              const framework::Tensor *y,                      \
                              framework::Tensor *z,                            \
                              const framework::ExecutionContext &ctx, int pre, \
281 282 283 284 285 286 287
                              int n, int post) {                               \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
Q
QI JUN 已提交
288 289 290
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
291 292 293 294
    }                                                                          \
  }

#define EIGEN_ADD(x, y) ((x) + (y))
295

296 297 298
EIGEN_FUNCTOR(Add, EIGEN_ADD);

#define EIGEN_SUB(x, y) ((x) - (y))
299

300 301 302
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

#define EIGEN_MUL(x, y) ((x) * (y))
303

304 305 306
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

#define EIGEN_DIV(x, y) ((x) / (y))
307

308 309
EIGEN_FUNCTOR(Div, EIGEN_DIV);

Y
Yu Yang 已提交
310 311
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
312 313 314 315
  const T *x_;
  const T *y_;
  const T *out_;
  const T *dout_;
Y
Yu Yang 已提交
316 317 318 319 320 321

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_ != nullptr) {
C
chengduoZH 已提交
322
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
Y
Yu Yang 已提交
323 324 325 326 327
    }
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
328 329
  T *dx_;
  T *dy_;
Y
Yu Yang 已提交
330 331 332
};

// CPU gradient for y broadcast along the rows of an (h, w) view of x:
//   dx[r*w + c] = dx_op(x, y[c], out, dout)          (if dx is non-null)
//   dy[c]       = sum over r of dy_op(x, y[c], out, dout) (if dy is non-null)
// dy[c] is overwritten on the first row, then accumulated in row order.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T *dx, T *dy) {
  for (int row = 0; row < h; ++row) {
    const int row_base = row * w;
    for (int col = 0; col < w; ++col) {
      const int idx = row_base + col;
      if (dx != nullptr) {
        dx[idx] = dx_op(x[idx], y[col], out[idx], dout[idx]);
      }
      if (dy == nullptr) {
        continue;
      }
      T grad = dy_op(x[idx], y[col], out[idx], dout[idx]);
      if (row == 0) {
        dy[col] = grad;
      } else {
        dy[col] += grad;
      }
    }
  }
}
353

D
dzhwinter 已提交
354
#ifdef __NVCC__
Y
Yu Yang 已提交
355 356
// One block per column `col` of an (h, w) view of x. Threads start at row
// threadIdx.x and step by ELEMWISE_MAX_BLOCK_DIM. dx is written in place; dy
// partial sums are combined with platform::reduceSum over the first
// min(h, ELEMWISE_MAX_BLOCK_DIM) threads, and thread 0 stores dy[col].
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const int col = blockIdx.x;
  const int tid = threadIdx.x;
  T partial = 0;

  int row = tid;
  do {
    const int offset = row * w + col;
    if (dx) {
      dx[offset] = dx_op(x[offset], y[col], out[offset], dout[offset]);
    }
    if (dy) {
      partial += dy_op(x[offset], y[col], out[offset], dout[offset]);
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dy) {
    const int active =
        h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    partial = paddle::platform::reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[col] = partial;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
385 386
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
                                       const T *y, const T *out, const T *dout,
Y
Yu Yang 已提交
387
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
388
                                       T *dx, T *dy) {
Y
Yu Yang 已提交
389 390
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  int gird_size = w;
C
chengduoZH 已提交
391 392
  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
Y
Yu Yang 已提交
393 394 395 396 397
}

#endif

// CPU gradient for y broadcast over a (pre, n, post) view of x:
//   dx[i, j, k] = dx_op(x, y[j], out, dout)                   (if dx non-null)
//   dy[j]       = sum over (i, k) of dy_op(x, y[j], out, dout) (if dy non-null)
// dy[j] is overwritten at (i, k) == (0, 0), then accumulated in row-major
// (i, k) order.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
                                      const T *dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      const int base = i * n * post + j * post;
      for (int k = 0; k < post; ++k) {
        const int idx = base + k;
        if (dx != nullptr) {
          dx[idx] = dx_op(x[idx], y[j], out[idx], dout[idx]);
        }
        if (dy == nullptr) {
          continue;
        }
        T grad = dy_op(x[idx], y[j], out[idx], dout[idx]);
        if (i == 0 && k == 0) {
          dy[j] = grad;
        } else {
          dy[j] += grad;
        }
      }
    }
  }
}

#ifdef __NVCC__
// One block per middle index j of a (pre, n, post) view of x. Each thread
// walks flattened (i, k) positions starting at threadIdx.x, stepping by
// ELEMWISE_MAX_BLOCK_DIM until i reaches pre. dx is written in place; dy
// partial sums are combined with platform::reduceSum over the first
// min(pre * post, ELEMWISE_MAX_BLOCK_DIM) threads, and thread 0 stores dy[j].
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  const int tid = threadIdx.x;
  const int j = blockIdx.x;
  T partial = 0;

  for (int flat = tid;; flat += ELEMWISE_MAX_BLOCK_DIM) {
    const int i = flat / post;
    const int k = flat % post;
    if (i >= pre) break;

    const int offset = i * n * post + j * post + k;
    if (dx != nullptr) {
      dx[offset] = dx_op(x[offset], y[j], out[offset], dout[offset]);
    }
    if (dy != nullptr) {
      partial += dy_op(x[offset], y[j], out[offset], dout[offset]);
    }
  }

  if (dy) {
    int active = pre * post;
    active = active > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : active;
    partial = paddle::platform::reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[j] = partial;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
462 463
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
                                       const T *y, const T *out, const T *dout,
Y
Yu Yang 已提交
464
                                       int pre, int n, int post, DX_OP dx_op,
465
                                       DY_OP dy_op, T *dx, T *dy) {
Y
Yu Yang 已提交
466 467
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int gird_size = n;
C
chengduoZH 已提交
468 469
  ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
Y
Yu Yang 已提交
470 471 472 473
}

#endif

474 475
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
476 477 478 479 480
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
481 482 483 484 485 486 487 488 489 490 491
  size_t N = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
492 493 494 495 496
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
  if (post == 1) {
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast2CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
          dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  }
}

Y
Yu Yang 已提交
539
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
540 541 542 543 544
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                         const framework::Tensor &x, const framework::Tensor &y,
                         const framework::Tensor &out,
                         const framework::Tensor &dout, int axis,
                         framework::Tensor *dx, framework::Tensor *dy,
Y
Yu Yang 已提交
545
                         DX_OP dx_op, DY_OP dy_op) {
546 547
  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
Y
Yu Yang 已提交
548
  if (x.dims() == y.dims()) {
549 550
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
Y
Yu Yang 已提交
551
  } else {  // Y is a scalar
552 553 554 555 556 557 558 559 560 561
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
562 563 564 565 566 567
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
                                 const framework::Tensor &x,
                                 const framework::Tensor &y,
                                 const framework::Tensor &out,
                                 const framework::Tensor &dout, int axis,
                                 framework::Tensor *dx, framework::Tensor *dy,
568 569
                                 DX_OP dx_op, DY_OP dy_op) {
  if (dy == nullptr) {
570
    const framework::DDim &dx_dims = dout.dims();
571 572 573 574 575
    auto dy_dims = dx_dims;
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    if (dout.dims() == dy->dims()) {
576 577
      const framework::DDim &dx_dims = dout.dims();
      const framework::DDim &dy_dims = dy->dims();
578 579 580 581
      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    } else {  // Y is a scalar
      auto dx_dims = dout.dims();
582
      const framework::DDim &dy_dims = dy->dims();
583 584
      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
Y
Yu Yang 已提交
585 586
    }
  }
587
}
Y
Yu Yang 已提交
588

589
// Deprecated
Q
QI JUN 已提交
590
template <typename DeviceContext, typename T, typename functor,
F
fengjiayi 已提交
591
          typename broadcastfunctor, typename broadcast2functor>
592 593 594 595 596 597 598
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
                            const framework::Tensor *x,
                            const framework::Tensor *y,
                            const framework::Tensor *out,
                            const framework::Tensor *dout, int axis,
                            framework::Tensor *dx, framework::Tensor *dy) {
  auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
617
  trim_trailing_singular_dims(y_dims);
618
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
619 620

  int pre, n, post;
621
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
622 623 624 625 626 627 628 629 630 631 632

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
F
fengjiayi 已提交
633

634 635
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
636 637 638 639
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                          const framework::Tensor *x,
                          const framework::Tensor *y, int axis, Functor func,
                          framework::Tensor *z) {
640
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
C
chengduoZH 已提交
641
      x, y, z, ctx.template device_context<DeviceContext>(), func);
F
fengjiayi 已提交
642 643

  auto x_dims = x->dims();
644 645
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
F
fengjiayi 已提交
646 647
                    "Rank of first input must >= rank of second input.");

648
  if (x_dims == y_dims_untrimed) {
F
fengjiayi 已提交
649 650 651 652
    functor.Run();
    return;
  }

653
  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
F
fengjiayi 已提交
654 655
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
656
  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
657
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
F
fengjiayi 已提交
658 659

  int pre, n, post;
660
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
F
fengjiayi 已提交
661 662 663 664 665 666 667 668 669
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}

670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 
1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487
// FusedElemwiseAndAct
// --- forward
// Per-element functor for the case where X and Y have identical dims: element
// i of X, Y, Out (and IntermediateOut, when kept) all live at flat index i.
// With KeepIntermediateOut the compound result is computed in two steps so the
// intermediate value can be stored; otherwise GetOut computes it in one call.
template <typename T, typename CompoundFunctor, bool KeepIntermediateOut>
struct FusedElemwiseAndActNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
    T x_val = x_[i];
    T y_val = y_[i];
    if (!KeepIntermediateOut) {
      out_[i] = compound_functor_.GetOut(x_val, y_val);
      return;
    }
    T inter = compound_functor_.GetIntermediateOut(x_val, y_val);
    intermediate_out_[i] = inter;
    out_[i] = compound_functor_.GetOutUseIntermediateOut(x_val, inter);
  }

  // Field order is the aggregate-initialization order used by callers.
  const T *x_;
  const T *y_;
  CompoundFunctor compound_functor_;
  T *out_;
  T *intermediate_out_;  // only dereferenced when KeepIntermediateOut
};

// FusedElemwiseAndActBroadcast1:
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2,
// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20)
//
// CPU version. The full-size operand is walked as an h x w row-major matrix;
// the broadcast operand is a length-w row indexed by column. BcastY == true
// means Y is that row, otherwise X is. With KeepIntermediateOut the
// intermediate value is stored too: indexed by column when Y is broadcast and
// the intermediate is not Out-shaped, and by the full matrix offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y,
                                             CompoundFunctor compound_functor,
                                             int h, int w, T *out,
                                             T *intermediate_out) {
  for (int row = 0; row < h; ++row) {
    const int row_base = row * w;
    for (int col = 0; col < w; ++col) {
      const int idx = row_base + col;
      T x_val = BcastY ? x[idx] : x[col];
      T y_val = BcastY ? y[col] : y[idx];

      if (!KeepIntermediateOut) {
        out[idx] = compound_functor.GetOut(x_val, y_val);
        continue;
      }

      T inter = compound_functor.GetIntermediateOut(x_val, y_val);
      // Only a broadcast Y whose intermediate is not Out-shaped (the
      // f1(x, f2(y)) form) keeps the intermediate in Y's layout.
      const int64_t inter_idx =
          (BcastY && !SameShapeOfIntermediateOutAndOut) ? col : idx;
      intermediate_out[inter_idx] = inter;
      out[idx] = compound_functor.GetOutUseIntermediateOut(x_val, inter);
    }
  }
}

// FusedElemwiseAndActBroadcast2
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1,
// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1)
// pre = 2, n = 12, post = 5
//
// CPU version. The full-size operand is walked as a (pre, n, post) block;
// the broadcast operand has n elements indexed by the middle coordinate.
// BcastY selects which operand is broadcast. With KeepIntermediateOut the
// intermediate value is stored too: indexed by the middle coordinate when Y is
// broadcast and the intermediate is not Out-shaped, by the full offset otherwise.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre,
                                             int n, int post,
                                             CompoundFunctor compound_functor,
                                             T *out, T *intermediate_out) {
  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      const int base = (outer * n + mid) * post;
      for (int inner = 0; inner < post; ++inner) {
        const int idx = base + inner;
        T x_val = BcastY ? x[idx] : x[mid];
        T y_val = BcastY ? y[mid] : y[idx];

        if (!KeepIntermediateOut) {
          out[idx] = compound_functor.GetOut(x_val, y_val);
          continue;
        }

        T inter = compound_functor.GetIntermediateOut(x_val, y_val);
        // A broadcast Y with a non-Out-shaped intermediate (f1(x, f2(y)))
        // stores the intermediate in Y's layout; all other forms use Out's.
        const int64_t inter_idx =
            (BcastY && !SameShapeOfIntermediateOutAndOut) ? mid : idx;
        intermediate_out[inter_idx] = inter;
        out[idx] = compound_functor.GetOutUseIntermediateOut(x_val, inter);
      }
    }
  }
}

#ifdef __NVCC__
// CUDA kernel for the (h, w) "Broadcast1" forward layout. Block blockIdx.x
// owns column `col`; its threads visit rows threadIdx.x, threadIdx.x +
// ELEMWISE_MAX_BLOCK_DIM, ... of that column.
// NOTE(review): the row stride is ELEMWISE_MAX_BLOCK_DIM rather than
// blockDim.x, which is only correct because the launcher uses
// min(ELEMWISE_MAX_BLOCK_DIM, h) threads.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
    const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
    T *out, T *intermediate_out) {
  const int col = blockIdx.x;

  for (int row = threadIdx.x; row < h; row += ELEMWISE_MAX_BLOCK_DIM) {
    const int idx = row * w + col;
    T x_val = BcastY ? x[idx] : x[col];
    T y_val = BcastY ? y[col] : y[idx];

    if (KeepIntermediateOut) {
      T inter = compound_functor.GetIntermediateOut(x_val, y_val);
      // Column-indexed only for a broadcast Y with a non-Out-shaped
      // intermediate (f1(x, f2(y))); otherwise Out's layout.
      const int64_t inter_idx =
          (BcastY && !SameShapeOfIntermediateOutAndOut) ? col : idx;
      intermediate_out[inter_idx] = inter;
      out[idx] = compound_functor.GetOutUseIntermediateOut(x_val, inter);
    } else {
      out[idx] = compound_functor.GetOut(x_val, y_val);
    }
  }
}

// Launches FusedElemwiseAndActBroadcast1CUDAKernel on `stream`: one block per
// column (w blocks), and at most ELEMWISE_MAX_BLOCK_DIM threads per block,
// which matches the kernel's row stride.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CUDA(cudaStream_t stream, const T *x,
                                              const T *y,
                                              CompoundFunctor compound_functor,
                                              int h, int w, T *out,
                                              T *intermediate_out) {
  const int threads = h < ELEMWISE_MAX_BLOCK_DIM ? h : ELEMWISE_MAX_BLOCK_DIM;
  const int blocks = w;
  FusedElemwiseAndActBroadcast1CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<blocks, threads, 0, stream>>>(
      x, y, h, w, compound_functor, out, intermediate_out);
}

// CUDA kernel for the (pre, n, post) "Broadcast2" forward layout. Block
// blockIdx.x owns middle index `mid`; its threads walk the flattened
// (pre, post) plane in steps of ELEMWISE_MAX_BLOCK_DIM until the outer
// coordinate reaches `pre`.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel(
    const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n,
    int post, T *out, T *intermediate_out) {
  const int mid = blockIdx.x;

  for (int flat = threadIdx.x; flat / post < pre;
       flat += ELEMWISE_MAX_BLOCK_DIM) {
    const int outer = flat / post;
    const int inner = flat % post;
    const int idx = outer * n * post + mid * post + inner;

    T x_val = BcastY ? x[idx] : x[mid];
    T y_val = BcastY ? y[mid] : y[idx];

    if (!KeepIntermediateOut) {
      out[idx] = compound_functor.GetOut(x_val, y_val);
      continue;
    }

    T inter = compound_functor.GetIntermediateOut(x_val, y_val);
    // Indexed by `mid` only for a broadcast Y with a non-Out-shaped
    // intermediate (f1(x, f2(y))); otherwise Out's layout.
    const int64_t inter_idx =
        (BcastY && !SameShapeOfIntermediateOutAndOut) ? mid : idx;
    intermediate_out[inter_idx] = inter;
    out[idx] = compound_functor.GetOutUseIntermediateOut(x_val, inter);
  }
}

// Launches FusedElemwiseAndActBroadcast2CUDAKernel on `stream`: one block per
// middle index (n blocks), and at most ELEMWISE_MAX_BLOCK_DIM threads covering
// each block's pre * post elements.
template <typename T, typename CompoundFunctor, bool BcastY,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CUDA(cudaStream_t stream, const T *x,
                                              const T *y, int pre, int n,
                                              int post,
                                              CompoundFunctor compound_functor,
                                              T *out, T *intermediate_out) {
  const int per_block = pre * post;
  const int threads =
      per_block < ELEMWISE_MAX_BLOCK_DIM ? per_block : ELEMWISE_MAX_BLOCK_DIM;
  const int blocks = n;

  FusedElemwiseAndActBroadcast2CUDAKernel<
      T, CompoundFunctor, BcastY, KeepIntermediateOut,
      SameShapeOfIntermediateOutAndOut><<<blocks, threads, 0, stream>>>(
      x, y, compound_functor, pre, n, post, out, intermediate_out);
}

#endif

// Same-shape forward: runs FusedElemwiseAndActNoBroadcast over every element
// of x_dim through platform::ForRange on the kernel's device context.
// `intermediate_out` may be null; it is only allocated when provided.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut>
void FusedElemwiseAndActComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::Tensor &x, const framework::Tensor &y,
    CompoundFunctor compound_functor, framework::Tensor *out,
    framework::Tensor *intermediate_out) {
  const auto numel = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), numel);

  // Pointers are fetched in the same order the aggregate initializer used.
  const T *x_data = x.data<T>();
  const T *y_data = y.data<T>();
  T *out_data = out->mutable_data<T>(ctx.GetPlace());
  T *inter_data = nullptr;
  if (intermediate_out != nullptr) {
    inter_data = intermediate_out->mutable_data<T>(ctx.GetPlace());
  }

  using Functor =
      FusedElemwiseAndActNoBroadcast<T, CompoundFunctor, KeepIntermediateOut>;
  for_range(Functor{x_data, y_data, compound_functor, out_data, inter_data});
}

// Broadcast forward dispatcher. `x_dim` is the output shape and
// `y_dim_untrimed` is the shape of the operand broadcast against it (BcastY
// says whether that operand is Y). The shapes are folded into (pre, n, post)
// by get_mid_dims; post == 1 uses the 2-D Broadcast1 path, anything else the
// 3-D Broadcast2 path, on GPU (NVCC builds only) or CPU.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool BcastY, bool KeepIntermediateOut,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
    const framework::Tensor &y, CompoundFunctor compound_functor, int axis,
    framework::Tensor *out, framework::Tensor *intermediate_out) {
  // axis == -1 aligns the broadcast operand with the trailing dims of x_dim.
  if (axis == -1) {
    axis = x_dim.size() - y_dim_untrimed.size();
  }
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  // Nothing left after trimming: align past the last dim of x_dim.
  if (y_dim.size() == 0) {
    axis = x_dim.size();
  }

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());
  // Evaluated only at the call sites below, so the intermediate buffer is
  // allocated exactly where a kernel is launched.
  auto intermediate_data = [&]() -> T * {
    return intermediate_out == nullptr
               ? nullptr
               : intermediate_out->mutable_data<T>(ctx.GetPlace());
  };

  if (post == 1) {
    // No dims follow the broadcast span: an (h, w) = (pre, n) matrix.
    const int h = pre;
    const int w = n;
    if (on_gpu) {
#ifdef __NVCC__
      FusedElemwiseAndActBroadcast1CUDA<T, CompoundFunctor, BcastY,
                                        KeepIntermediateOut,
                                        SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), compound_functor, h, w,
          out->mutable_data<T>(ctx.GetPlace()), intermediate_data());
#endif
    } else {
      FusedElemwiseAndActBroadcast1CPU<T, CompoundFunctor, BcastY,
                                       KeepIntermediateOut,
                                       SameShapeOfIntermediateOutAndOut>(
          x.data<T>(), y.data<T>(), compound_functor, h, w,
          out->mutable_data<T>(ctx.GetPlace()), intermediate_data());
    }
    return;
  }

  // General case: (pre, n, post) with the broadcast operand spanning n.
  if (on_gpu) {
#ifdef __NVCC__
    FusedElemwiseAndActBroadcast2CUDA<T, CompoundFunctor, BcastY,
                                      KeepIntermediateOut,
                                      SameShapeOfIntermediateOutAndOut>(
        ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
        y.data<T>(), pre, n, post, compound_functor,
        out->mutable_data<T>(ctx.GetPlace()), intermediate_data());
#endif
  } else {
    FusedElemwiseAndActBroadcast2CPU<T, CompoundFunctor, BcastY,
                                     KeepIntermediateOut,
                                     SameShapeOfIntermediateOutAndOut>(
        x.data<T>(), y.data<T>(), pre, n, post, compound_functor,
        out->mutable_data<T>(ctx.GetPlace()), intermediate_data());
  }
}

// --- backward
// Same-shape backward functor. For flat index i it evaluates dx_op_ / dy_op_
// on the i-th element of every input and stores the result in dx_[i] /
// dy_[i]; a null dx_ or dy_ means that gradient is not requested.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = Apply(dx_op_, i);
    }
    if (dy_ != nullptr) {
      dy_[i] = Apply(dy_op_, i);
    }
  }

  // Field order is the aggregate-initialization order used by callers.
  const T *x_;
  const T *y_;
  const T *intermediate_out_;  // only dereferenced when UseIntermediateOut
  const T *out_;
  const T *dout_;
  DX_OP dx_op_;
  DY_OP dy_op_;
  T *dx_;
  T *dy_;

 private:
  // Calls `op` with the 5-argument (intermediate) or 4-argument signature
  // depending on UseIntermediateOut; both must exist since both are compiled.
  template <typename Op>
  HOSTDEVICE T Apply(Op &op, size_t i) {
    return UseIntermediateOut
               ? op(x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
               : op(x_[i], y_[i], out_[i], dout_[i]);
  }
};

// Same-shape backward: runs FusedElemwiseAndActGradNoBroadcast over every
// element of x_dim via platform::ForRange. `y_dim` and `axis` are accepted
// for signature parity with the broadcast path and are not used here.
// Null `dx` / `dy` skip that gradient; null `intermediate_out` is passed on.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          bool UseIntermediateOut>
void FusedElemwiseAndActGradComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
    framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  const auto numel = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), numel);

  // Fetched in the same order as the original aggregate initializer.
  const T *x_data = x->data<T>();
  const T *y_data = y->data<T>();
  const T *inter_data =
      intermediate_out ? intermediate_out->data<T>() : nullptr;
  const T *out_data = out->data<T>();
  const T *dout_data = dout->data<T>();
  T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());

  using Functor =
      FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, UseIntermediateOut>;
  for_range(Functor{x_data, y_data, inter_data, out_data, dout_data, dx_op,
                    dy_op, dx_data, dy_data});
}

// CPU backward for the (h, w) "Broadcast1" layout. The full-size operand's
// gradient is written per element; the broadcast operand's gradient (a
// length-w row) is the column-wise sum over all h rows, initialized on row 0
// and accumulated afterwards. Null dx / dy skip that gradient.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
          bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
                                                 const T *intermediate_out,
                                                 const T *out, const T *dout,
                                                 int h, int w, DX_OP dx_op,
                                                 DY_OP dy_op, T *dx, T *dy) {
  for (int row = 0; row < h; ++row) {
    const bool first_row = (row == 0);
    for (int col = 0; col < w; ++col) {
      const int idx = row * w + col;
      // The broadcast operand is addressed by column, the other by idx.
      const int64_t xi = BcastY ? idx : col;
      const int64_t yi = BcastY ? col : idx;
      const int64_t ti =
          (BcastY && !SameShapeOfIntermediateOutAndOut) ? col : idx;

      if (dx != nullptr) {
        T grad = UseIntermediateOut
                     ? dx_op(x[xi], y[yi], intermediate_out[ti], out[idx],
                             dout[idx])
                     : dx_op(x[xi], y[yi], out[idx], dout[idx]);
        if (BcastY || first_row) {
          dx[xi] = grad;
        } else {
          dx[xi] += grad;
        }
      }

      if (dy != nullptr) {
        T grad = UseIntermediateOut
                     ? dy_op(x[xi], y[yi], intermediate_out[ti], out[idx],
                             dout[idx])
                     : dy_op(x[xi], y[yi], out[idx], dout[idx]);
        if (!BcastY || first_row) {
          dy[yi] = grad;
        } else {
          dy[yi] += grad;
        }
      }
    }
  }
}

// CPU backward for the (pre, n, post) "Broadcast2" layout. The full-size
// operand's gradient is written per element; the broadcast operand's gradient
// (n elements) is the sum over the pre and post coordinates. Each slot is
// first reached at (outer, inner) == (0, 0), where it is initialized; later
// visits accumulate. Null dx / dy skip that gradient.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
          bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
                                                 const T *intermediate_out,
                                                 const T *out, const T *dout,
                                                 int pre, int n, int post,
                                                 DX_OP dx_op, DY_OP dy_op,
                                                 T *dx, T *dy) {
  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      const int base = (outer * n + mid) * post;
      for (int inner = 0; inner < post; ++inner) {
        const int idx = base + inner;
        const bool first_visit = (outer == 0 && inner == 0);
        const int64_t xi = BcastY ? idx : mid;
        const int64_t yi = BcastY ? mid : idx;
        const int64_t ti =
            (BcastY && !SameShapeOfIntermediateOutAndOut) ? mid : idx;

        if (dx != nullptr) {
          T grad = UseIntermediateOut
                       ? dx_op(x[xi], y[yi], intermediate_out[ti], out[idx],
                               dout[idx])
                       : dx_op(x[xi], y[yi], out[idx], dout[idx]);
          if (BcastY || first_visit) {
            dx[xi] = grad;
          } else {
            dx[xi] += grad;
          }
        }

        if (dy != nullptr) {
          T grad = UseIntermediateOut
                       ? dy_op(x[xi], y[yi], intermediate_out[ti], out[idx],
                               dout[idx])
                       : dy_op(x[xi], y[yi], out[idx], dout[idx]);
          if (!BcastY || first_visit) {
            dy[yi] = grad;
          } else {
            dy[yi] += grad;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
// CUDA backward kernel for the (h, w) "Broadcast1" layout.
// Block blockIdx.x handles column j; its threads visit rows
// threadIdx.x, threadIdx.x + ELEMWISE_MAX_BLOCK_DIM, ... of that column.
// The full-size operand's gradient is written per element. The broadcast
// operand's gradient for column j is a sum over rows: each thread keeps a
// partial sum in `val`, the block combines them with
// paddle::platform::reduceSum, and thread 0 stores the result.
// NOTE(review): the row stride is ELEMWISE_MAX_BLOCK_DIM, not blockDim.x. This
// is correct only because FusedElemwiseAndActGradBroadcast1CUDA launches
// min(ELEMWISE_MAX_BLOCK_DIM, h) threads per block.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
          bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  int j = blockIdx.x;
  int i = threadIdx.x;
  int tid = threadIdx.x;
  // Per-thread partial sum of the broadcast operand's gradient.
  T val(0);
  int64_t tmp_out_idx, x_idx, y_idx;

  // do-while runs the body at least once. With the launcher's block size of
  // min(ELEMWISE_MAX_BLOCK_DIM, h), threadIdx.x < h holds on entry.
  do {
    int offset = i * w + j;

    // The broadcast operand is indexed by column j, the other by offset.
    // The intermediate follows Y's layout unless it is Out-shaped.
    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
      T tmp = UseIntermediateOut
                  ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
                          out[offset], dout[offset])
                  : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);

      if (BcastY) {
        // X is full-size: its gradient is element-wise.
        dx[x_idx] = tmp;
      } else {
        // X is broadcast: accumulate for the block reduction below.
        val += tmp;
      }
    }
    if (dy != nullptr) {
      T tmp = UseIntermediateOut
                  ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
                          out[offset], dout[offset])
                  : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
      if (BcastY) {
        // Y is broadcast: accumulate for the block reduction below.
        val += tmp;
      } else {
        // Y is full-size: its gradient is element-wise.
        dy[y_idx] = tmp;
      }
    }

    i += ELEMWISE_MAX_BLOCK_DIM;
  } while (i < h);

  // The guards below (BcastY, and dx/dy being non-null) are the same for
  // every thread of the block, so either all threads call reduceSum or none do.
  // `h` is clamped the same way the launcher sizes the block, so it equals the
  // number of participating threads.
  if (BcastY) {
    if (dy) {
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
}

// Launches FusedElemwiseAndActGradBroadcast1CUDAKernel on `stream`: one block
// per column (w blocks), and min(ELEMWISE_MAX_BLOCK_DIM, h) threads. The
// kernel's row stride and reduction length both assume exactly this block size.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
          bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(cudaStream_t stream,
                                                  const T *x, const T *y,
                                                  const T *intermediate_out,
                                                  const T *out, const T *dout,
                                                  int h, int w, DX_OP dx_op,
                                                  DY_OP dy_op, T *dx, T *dy) {
  const int threads = h < ELEMWISE_MAX_BLOCK_DIM ? h : ELEMWISE_MAX_BLOCK_DIM;
  const int blocks = w;
  FusedElemwiseAndActGradBroadcast1CUDAKernel<
      T, DX_OP, DY_OP, UseIntermediateOut, BcastY,
      SameShapeOfIntermediateOutAndOut><<<blocks, threads, 0, stream>>>(
      x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dx, dy);
}

// CUDA backward kernel for the (pre, n, post) "Broadcast2" layout.
// Block blockIdx.x handles middle index j; its threads walk the flattened
// (pre, post) plane (ttid -> i = ttid / post, k = ttid % post) in steps of
// ELEMWISE_MAX_BLOCK_DIM. The full-size operand's gradient is written per
// element. The broadcast operand's gradient for index j is a sum over
// (i, k): each thread keeps a partial sum in `val`, the block combines them
// with paddle::platform::reduceSum, and thread 0 stores the result.
// NOTE(review): the stride is ELEMWISE_MAX_BLOCK_DIM, not blockDim.x. This
// relies on FusedElemwiseAndActGradBroadcast2CUDA launching
// min(ELEMWISE_MAX_BLOCK_DIM, pre * post) threads.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
          bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
    const T *x, const T *y, const T *intermediate_out, const T *out,
    const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, T *dx,
    T *dy) {
  int tid = threadIdx.x;
  int j = blockIdx.x;

  // Per-thread partial sum of the broadcast operand's gradient.
  T val(0);
  int ttid = tid;
  int64_t tmp_out_idx, x_idx, y_idx;
  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int offset = i * n * post + j * post + k;

    // The broadcast operand is indexed by j, the other by offset.
    // The intermediate follows Y's layout unless it is Out-shaped.
    tmp_out_idx = BcastY ? j : offset;
    y_idx = BcastY ? j : offset;
    x_idx = BcastY ? offset : j;

    if (SameShapeOfIntermediateOutAndOut) {
      tmp_out_idx = offset;
    }

    if (dx != nullptr) {
      T tmp = UseIntermediateOut
                  ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
                          out[offset], dout[offset])
                  : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);

      if (BcastY) {
        // X is full-size: its gradient is element-wise.
        dx[x_idx] = tmp;
      } else {
        // X is broadcast: accumulate for the block reduction below.
        val += tmp;
      }
    }
    if (dy != nullptr) {
      T tmp = UseIntermediateOut
                  ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
                          out[offset], dout[offset])
                  : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
      if (BcastY) {
        // Y is broadcast: accumulate for the block reduction below.
        val += tmp;
      } else {
        // Y is full-size: its gradient is element-wise.
        dy[y_idx] = tmp;
      }
    }

    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

  // The guards below are the same for every thread of the block, so either
  // all threads call reduceSum or none do. `h` is clamped the same way the
  // launcher sizes the block, so it equals the number of participating threads.
  if (BcastY) {
    if (dy) {
      int h = pre * post;
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dy[j] = val;
      }
    }
  } else {
    if (dx) {
      int h = pre * post;
      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
      val = paddle::platform::reduceSum(val, tid, h);
      if (threadIdx.x == 0) {
        dx[j] = val;
      }
    }
  }
}

// Launches FusedElemwiseAndActGradBroadcast2CUDAKernel on `stream`: one block
// per middle index (n blocks), and min(ELEMWISE_MAX_BLOCK_DIM, pre * post)
// threads. The kernel's stride and reduction length both assume this block size.
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
          bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CUDA(
    cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
    const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
    DY_OP dy_op, T *dx, T *dy) {
  const int per_block = pre * post;
  const int threads =
      per_block < ELEMWISE_MAX_BLOCK_DIM ? per_block : ELEMWISE_MAX_BLOCK_DIM;
  const int blocks = n;
  FusedElemwiseAndActGradBroadcast2CUDAKernel<
      T, DX_OP, DY_OP, UseIntermediateOut, BcastY,
      SameShapeOfIntermediateOutAndOut><<<blocks, threads, 0, stream>>>(
      x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif

// Broadcast backward dispatcher. `x_dim` is the output shape and
// `y_dim_untrimed` is the shape of the broadcast operand (BcastY says whether
// that is Y). get_mid_dims folds the shapes into (pre, n, post); post == 1
// uses the 2-D Broadcast1 kernels, otherwise the 3-D Broadcast2 kernels, on
// GPU (NVCC builds only) or CPU. Null dx / dy skip that gradient.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          bool UseIntermediateOut, bool BcastY,
          bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
    const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *intermediate_out,
    const framework::Tensor *out, const framework::Tensor *dout, int axis,
    framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  // axis == -1 aligns the broadcast operand with the trailing dims of x_dim.
  if (axis == -1) {
    axis = x_dim.size() - y_dim_untrimed.size();
  }
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  // Nothing left after trimming: align past the last dim of x_dim.
  if (y_dim.size() == 0) {
    axis = x_dim.size();
  }

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  // Evaluated only at the call sites below, so buffers are touched exactly
  // where a kernel is launched.
  auto intermediate_data = [&]() -> const T * {
    return intermediate_out == nullptr ? nullptr : intermediate_out->data<T>();
  };
  auto grad_data = [&](framework::Tensor *grad) -> T * {
    return grad == nullptr ? nullptr : grad->mutable_data<T>(ctx.GetPlace());
  };
  const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());

  if (post == 1) {
    // No dims follow the broadcast span: an (h, w) = (pre, n) matrix.
    const int h = pre;
    const int w = n;
    if (on_gpu) {
#ifdef __NVCC__
      FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, UseIntermediateOut,
                                            BcastY,
                                            SameShapeOfIntermediateOutAndOut>(
          ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
          y->data<T>(), intermediate_data(), out->data<T>(), dout->data<T>(),
          h, w, dx_op, dy_op, grad_data(dx), grad_data(dy));
#endif
    } else {
      FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, UseIntermediateOut,
                                           BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
          x->data<T>(), y->data<T>(), intermediate_data(), out->data<T>(),
          dout->data<T>(), h, w, dx_op, dy_op, grad_data(dx), grad_data(dy));
    }
  } else if (on_gpu) {
    // General case: (pre, n, post) with the broadcast operand spanning n.
#ifdef __NVCC__
    FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, UseIntermediateOut,
                                          BcastY,
                                          SameShapeOfIntermediateOutAndOut>(
        ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
        y->data<T>(), intermediate_data(), out->data<T>(), dout->data<T>(),
        pre, n, post, dx_op, dy_op, grad_data(dx), grad_data(dy));
#endif
  } else {
    FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, UseIntermediateOut,
                                         BcastY,
                                         SameShapeOfIntermediateOutAndOut>(
        x->data<T>(), y->data<T>(), intermediate_data(), out->data<T>(),
        dout->data<T>(), pre, n, post, dx_op, dy_op, grad_data(dx),
        grad_data(dy));
  }
}

// Backward entry point for the fused elementwise + activation ops.
// Equal dims use the element-wise path. Otherwise the operand to broadcast is
// chosen: Y, unless X has lower rank or, at equal rank, some dim of X is
// smaller than Y's, in which case X is broadcast and Y's shape is passed as
// the output shape.
// Requires `intermediate_out` when UseIntermediateOut is set.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
          bool UseIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeEx(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *y, const framework::Tensor *out,
    const framework::Tensor *intermediate_out, const framework::Tensor *dout,
    int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op,
    DY_OP dy_op) {
  const framework::DDim &x_dim = x->dims();
  const framework::DDim &y_dim = y->dims();
  if (UseIntermediateOut) {
    PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
  }

  if (x_dim == y_dim) {
    FusedElemwiseAndActGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP,
                                              UseIntermediateOut>(
        ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
        dx_op, dy_op);
    return;
  }

  // Shapes differ (general broadcasting, not only a scalar Y).
  bool bcast_y = x_dim.size() >= y_dim.size();
  if (x_dim.size() == y_dim.size()) {
    for (int i = 0; bcast_y && i < x_dim.size(); ++i) {
      bcast_y = x_dim[i] >= y_dim[i];
    }
  }

  // z = f1(x, f2(y))
  // z = f1(f2(x, y))
  if (bcast_y) {
    // Y is broadcast; X's shape is the output shape.
    FusedElemwiseAndActGradComputeWithBroadcast<
        DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, true /*BcastY*/,
        SameShapeOfIntermediateOutAndOut>(ctx, x_dim, y_dim, x, y,
                                          intermediate_out, out, dout, axis,
                                          dx, dy, dx_op, dy_op);
  } else {
    // X is broadcast; Y's shape is the output shape.
    FusedElemwiseAndActGradComputeWithBroadcast<
        DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, false /*BcastY*/,
        SameShapeOfIntermediateOutAndOut>(ctx, y_dim, x_dim, x, y,
                                          intermediate_out, out, dout, axis,
                                          dx, dy, dx_op, dy_op);
  }
}

// Forward entry point for the fused elementwise + activation ops.
// Equal dims use the element-wise path. Otherwise Y is broadcast (Out takes
// X's shape), unless X has lower rank or, at equal rank, some dim of X is
// smaller than Y's. In that case X is broadcast and Out takes Y's shape.
// Requires `intermediate_out` when KeepIntermediateOut is set.
template <typename DeviceContext, typename T, typename CompoundFunctor,
          bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
                                  const framework::Tensor &x,
                                  const framework::Tensor &y, int axis,
                                  CompoundFunctor compound_functor,
                                  framework::Tensor *out,
                                  framework::Tensor *intermediate_out) {
  if (KeepIntermediateOut) {
    PADDLE_ENFORCE(intermediate_out,
                   "The keep_intermediate_value is opened, "
                   "intermediate_out should not be nullptr.");
  }

  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
  if (x_dim == y_dim) {
    FusedElemwiseAndActComputeNoBroadcast<DeviceContext, T, CompoundFunctor,
                                          KeepIntermediateOut>(
        ctx, x_dim, x, y, compound_functor, out, intermediate_out);
    return;
  }

  // Whether Y's shape is a continuous subsequence of X's (see the op's
  // introduction for details).
  bool bcast_y = x_dim.size() >= y_dim.size();
  if (x_dim.size() == y_dim.size()) {
    for (int i = 0; bcast_y && i < x_dim.size(); ++i) {
      bcast_y = x_dim[i] >= y_dim[i];
    }
  }

  // z = f1(x, f2(y))
  // z = f1(f2(x, y))
  if (bcast_y) {
    // Y is broadcast and Out has X's shape. The intermediate has Y's shape
    // for f2(y) and Out's shape for f2(x, y).
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, true /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out,
        intermediate_out);
  } else {
    // X is broadcast and Out has Y's shape. The intermediate has Out's shape
    // for both f2(y) and f2(x, y).
    FusedElemwiseAndActComputeWithBroadcast<
        DeviceContext, T, CompoundFunctor, false /*BcastY*/,
        KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
        ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out,
        intermediate_out);
  }
}
1488 1489
}  // namespace operators
}  // namespace paddle