/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#pragma once
16
#include <algorithm>
Y
Yi Wang 已提交
17 18 19 20
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
21

C
chengduoZH 已提交
22
#ifdef __NVCC__
23
#include <cuda.h>
C
chengduoZH 已提交
24
#include <thrust/iterator/iterator_adaptor.h>
D
dzhwinter 已提交
25
#include "paddle/fluid/platform/cuda_primitives.h"
Y
Yu Yang 已提交
26
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
C
chengduoZH 已提交
27 28
#endif

Y
Yi Wang 已提交
29
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
30
#include "paddle/fluid/platform/for_range.h"
31 32 33 34 35 36 37 38 39 40

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, X is viewed as a 3-D tensor
 * (pre, n, post) and Y is broadcast along the pre and post axes:
 *   pre  = product of X's dims before `axis`,
 *   n    = product of Y's dims (which must equal X's dims starting at `axis`),
 *   post = product of X's dims after the span covered by Y.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int* pre, int* n, int* post) {
  *pre = 1;
  *n = 1;
  *post = 1;

  const int y_rank = y_dims.size();
  const int x_rank = x_dims.size();

  // Leading dims of X, before Y's span, fold into `pre`.
  for (int d = 0; d < axis; ++d) {
    *pre *= x_dims[d];
  }

  // Y's dims must line up with X's dims starting at `axis`.
  for (int d = 0; d < y_rank; ++d) {
    PADDLE_ENFORCE_EQ(x_dims[d + axis], y_dims[d],
                      "Broadcast dimension mismatch.");
    *n *= y_dims[d];
  }

  // Trailing dims of X, after Y's span, fold into `post`.
  for (int d = axis + y_rank; d < x_rank; ++d) {
    *post *= x_dims[d];
  }
}

67
// Removes trailing dimensions of extent 1 from *dims in place, e.g.
// (3, 4, 1, 1) -> (3, 4). If every dimension is 1, *dims becomes empty
// (size() == 0). *dims is left untouched when its last dimension is not 1.
inline void trim_trailing_singular_dims(framework::DDim* dims) {
  auto kept = dims->size();
  while (kept != 0 && (*dims)[kept - 1] == 1) {
    --kept;
  }
  if (kept == dims->size()) {
    return;  // nothing to trim
  }
  auto shape = framework::vectorize(*dims);
  shape.resize(kept);
  *dims = framework::make_ddim(shape);
}

Q
QI JUN 已提交
80
template <typename T, typename DeviceContext>
C
chengduoZH 已提交
81
class RowwiseTransformIterator;
Q
QI JUN 已提交
82
template <typename T, typename DeviceContext>
C
chengduoZH 已提交
83
class MidWiseTransformIterator;
C
chengduoZH 已提交
84 85

template <typename T>
Q
QI JUN 已提交
86
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
87
 public:
C
chengduoZH 已提交
88 89
  RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

Q
QI JUN 已提交
90
  RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
91
    ++i_;
C
chengduoZH 已提交
92 93 94
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
C
chengduoZH 已提交
95 96 97
    return *this;
  }

Q
QI JUN 已提交
98 99
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
100
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
101 102
  }

Q
QI JUN 已提交
103 104
  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
105
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
106 107 108 109
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
110
 private:
C
chengduoZH 已提交
111 112
  const T* ptr_;
  int i_;
C
chengduoZH 已提交
113
  int64_t n_;
C
chengduoZH 已提交
114 115 116
};

template <typename T>
Q
QI JUN 已提交
117
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
118
 public:
C
chengduoZH 已提交
119 120 121
  MidWiseTransformIterator(const T* ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

Q
QI JUN 已提交
122
  MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
123
    ++j_;
C
chengduoZH 已提交
124 125
    if (UNLIKELY(j_ == post_)) {
      ++i_;
C
refine  
chengduoZH 已提交
126
      j_ = 0;
C
chengduoZH 已提交
127 128 129
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
C
chengduoZH 已提交
130
    }
C
chengduoZH 已提交
131 132 133
    return *this;
  }

Q
QI JUN 已提交
134 135
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
136
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
137 138
  }

Q
QI JUN 已提交
139 140
  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
141
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
142 143 144 145
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
146
 private:
C
chengduoZH 已提交
147
  const T* ptr_;
C
refine  
chengduoZH 已提交
148
  int64_t i_;
C
chengduoZH 已提交
149 150
  int64_t j_;
  int64_t n_;
C
refine  
chengduoZH 已提交
151
  int64_t post_;
C
chengduoZH 已提交
152 153
};

C
chengduoZH 已提交
154 155
#ifdef __NVCC__
template <typename T>
Q
QI JUN 已提交
156
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
157
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
158
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
159 160
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
161
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
162
      super_t;
C
chengduoZH 已提交
163
  HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
164
      : super_t(x), begin_(x), n_(n) {}
C
chengduoZH 已提交
165 166 167 168 169
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
170
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
171 172 173 174 175
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
Q
QI JUN 已提交
176
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
177
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
178
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
179 180
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
181
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
182
      super_t;
C
chengduoZH 已提交
183
  HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
184
      : super_t(x), begin_(x), n_(n), post_(post) {}
C
chengduoZH 已提交
185 186 187 188 189 190
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
191
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
192 193 194 195 196
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

197 198
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
199 200
class TransformFunctor {
 public:
C
chengduoZH 已提交
201
  TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
Q
QI JUN 已提交
202
                   framework::Tensor* z, const DeviceContext& ctx, Functor func)
C
chengduoZH 已提交
203 204
      : x_(x->data<T>()),
        y_(y->data<T>()),
205
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
206 207 208 209 210
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
Q
QI JUN 已提交
211
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
212
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
213 214 215
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
216 217 218
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
C
chengduoZH 已提交
219 220 221
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
222 223 224
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
C
chengduoZH 已提交
225 226
  }

C
chengduoZH 已提交
227
 private:
C
chengduoZH 已提交
228 229
  const T* x_;
  const T* y_;
230
  OutType* z_;
C
chengduoZH 已提交
231
  int64_t nx_;
Q
QI JUN 已提交
232
  const DeviceContext& ctx_;
C
chengduoZH 已提交
233 234 235
  Functor func_;
};

236 237
#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
Q
QI JUN 已提交
238
    template <typename DeviceContext, typename T>                              \
239 240 241 242 243 244
    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
                    framework::Tensor* z,                                      \
                    const framework::ExecutionContext& ctx) {                  \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
Q
QI JUN 已提交
245 246 247
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_e);                                                  \
248
    }                                                                          \
Q
QI JUN 已提交
249
    template <typename DeviceContext, typename T>                              \
250 251 252 253 254 255 256 257 258 259
    inline void RunBroadCast(const framework::Tensor* x,                       \
                             const framework::Tensor* y, framework::Tensor* z, \
                             const framework::ExecutionContext& ctx, int pre,  \
                             int n) {                                          \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
Q
QI JUN 已提交
260 261 262
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
263
    }                                                                          \
Q
QI JUN 已提交
264
    template <typename DeviceContext, typename T>                              \
265 266 267 268 269 270 271 272 273 274 275
    inline void RunBroadCast2(const framework::Tensor* x,                      \
                              const framework::Tensor* y,                      \
                              framework::Tensor* z,                            \
                              const framework::ExecutionContext& ctx, int pre, \
                              int n, int post) {                               \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
Q
QI JUN 已提交
276 277 278
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
    }                                                                          \
  }

#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);

#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);

Y
Yu Yang 已提交
294 295 296 297 298 299 300 301 302 303 304 305
// Per-element gradient functor for the same-shape case: for index i it
// writes dx_[i] = dx_op_(x, y, out, dout) and dy_[i] = dy_op_(x, y, out, dout)
// at element i, skipping whichever output pointer is null.
// NOTE: the data-member order below is part of the interface, because
// instances are built with aggregate (brace) initialization.
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  const T* x_;
  const T* y_;
  const T* out_;
  const T* dout_;
  DX_OP dx_op_;
  DY_OP dy_op_;
  T* dx_;  // may be nullptr: dx not requested
  T* dy_;  // may be nullptr: dy not requested

  HOSTDEVICE void operator()(size_t i) {
    if (dx_) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_) {
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
  }
};

// CPU gradient for the broadcast case where x is an (h, w) row-major matrix
// and y has w elements shared by every row.
//   dx (if non-null): dx[r][c] = dx_op(x[r][c], y[c], out[r][c], dout[r][c]).
//   dy (if non-null): dy[c] = sum over rows r of dy_op(...). It is
//                     overwritten by row 0 and accumulated for later rows.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T* dx, T* dy) {
  int idx = 0;  // running row-major offset, equal to row * w + col
  for (int row = 0; row < h; ++row) {
    for (int col = 0; col < w; ++col, ++idx) {
      if (dx != nullptr) {
        dx[idx] = dx_op(x[idx], y[col], out[idx], dout[idx]);
      }
      if (dy != nullptr) {
        const T grad = dy_op(x[idx], y[col], out[idx], dout[idx]);
        if (row == 0) {
          dy[col] = grad;
        } else {
          dy[col] += grad;
        }
      }
    }
  }
}
337

D
dzhwinter 已提交
338
#ifdef __NVCC__
339 340 341

template <typename T>
// Block-wide sum of `val`; the complete sum is returned in thread 0 (callers
// only consume the result where threadIdx.x == 0). `len` is the number of
// participating threads and is used only to build the first shuffle mask.
// Stage 1 reduces within each warp via shuffle-down, stage 2 publishes one
// partial per warp to shared memory, and stage 3 has the first warp reduce
// those partials.
__device__ T reduceSum(T val, int tid, int len) {
  // NOTE(zcd): The warp size should be taken from the parameters of the GPU
  // rather than hard-coded as 32. To make reduceSum efficient, it uses
  // warp-level parallelism and assumes a warp size of 32, which can differ
  // between GPUs, although most cards use 32.
  const int kWarpSize = 32;
  // One slot per warp: 32 slots cover up to 32 * 32 = 1024 threads.
  __shared__ T warp_partials[kWarpSize];
  unsigned mask = 0u;

  // Stage 1: intra-warp tree reduction.
  // NOTE(review): when len is not a multiple of 32, the last warp is partial
  // and shuffles may read lanes that do not exist; the CUDA docs leave such
  // reads undefined. Confirm whether callers can hit this.
  CREATE_SHFL_MASK(mask, tid < len);
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    val += platform::__shfl_down_sync(mask, val, offset);
  }

  // Stage 2: clear the slots, then lane 0 of each warp stores its partial.
  if (tid < kWarpSize) {
    warp_partials[tid] = 0;
  }
  __syncthreads();
  if (tid % kWarpSize == 0) {
    warp_partials[tid / kWarpSize] = val;
  }
  __syncthreads();

  // Stage 3: the first warp sums the per-warp partials.
  CREATE_SHFL_MASK(mask, tid < kWarpSize);
  if (tid < kWarpSize) {
    val = warp_partials[tid];
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      val += platform::__shfl_down_sync(mask, val, offset);
    }
  }

  return val;
}

Y
Yu Yang 已提交
376 377 378 379 380 381 382
// GPU counterpart of ElemwiseGradBroadcast1CPU. Each block handles one column
// (blockIdx.x), and each thread visits rows tid, tid + ELEMWISE_MAX_BLOCK_DIM,
// and so on. dy[col] is the block-wide sum computed by reduceSum.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int col = blockIdx.x;
  const int tid = threadIdx.x;
  int row = tid;
  T partial = 0;

  do {
    const int offset = row * w + col;
    if (dx) {
      dx[offset] = dx_op(x[offset], y[col], out[offset], dout[offset]);
    }
    if (dy) {
      partial += dy_op(x[offset], y[col], out[offset], dout[offset]);
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dy) {
    // The number of threads taking part matches the launch's block size.
    const int active =
        h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    partial = reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[col] = partial;
    }
  }
}

// Launches ElemwiseGradBroadcast1CUDAKernel on `stream` with one block per
// column (w blocks) and min(ELEMWISE_MAX_BLOCK_DIM, h) threads per block.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T* dx, T* dy) {
  const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  const int blocks = w;
  ElemwiseGradBroadcast1CUDAKernel<<<blocks, threads, 0, stream>>>(
      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
}

#endif

// CPU gradient for the broadcast case where x is viewed as a row-major
// (pre, n, post) tensor and y has n elements indexed by the middle axis.
//   dx (if non-null): dx[i][j][k] = dx_op(x[i][j][k], y[j], out[...], dout[...]).
//   dy (if non-null): dy[j] = sum over i and k of dy_op(...). It is
//                     overwritten at (i, k) == (0, 0) and accumulated after.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  int idx = 0;  // running offset, equal to i * n * post + j * post + k
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k, ++idx) {
        if (dx != nullptr) {
          dx[idx] = dx_op(x[idx], y[j], out[idx], dout[idx]);
        }
        if (dy != nullptr) {
          const T grad = dy_op(x[idx], y[j], out[idx], dout[idx]);
          const bool first_visit = (i == 0 && k == 0);
          if (first_visit) {
            dy[j] = grad;
          } else {
            dy[j] += grad;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
// GPU counterpart of ElemwiseGradBroadcast2CPU. Each block handles one index
// j of the middle axis (blockIdx.x). Each thread walks the flattened
// (pre, post) plane starting at tid with stride ELEMWISE_MAX_BLOCK_DIM, and
// dy[j] is the block-wide sum computed by reduceSum.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int tid = threadIdx.x;
  const int mid = blockIdx.x;
  T partial = 0;

  for (int flat = tid;; flat += ELEMWISE_MAX_BLOCK_DIM) {
    const int outer = flat / post;
    const int inner = flat % post;
    if (outer >= pre) break;

    const int offset = outer * n * post + mid * post + inner;

    if (dx != nullptr) {
      dx[offset] = dx_op(x[offset], y[mid], out[offset], dout[offset]);
    }
    if (dy != nullptr) {
      partial += dy_op(x[offset], y[mid], out[offset], dout[offset]);
    }
  }

  if (dy) {
    // The number of threads taking part matches the launch's block size.
    int active = pre * post;
    active = active > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : active;
    partial = reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[mid] = partial;
    }
  }
}

// Launches ElemwiseGradBroadcast2CUDAKernel on `stream` with one block per
// middle-axis index (n blocks) and min(ELEMWISE_MAX_BLOCK_DIM, pre * post)
// threads per block.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T* dx, T* dy) {
  const int threads = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  const int blocks = n;
  ElemwiseGradBroadcast2CUDAKernel<<<blocks, threads, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}

#endif

// Computes the gradients of a binary elementwise op out = f(x, y), given
// dout. dx_op / dy_op are called as op(x_i, y_i, out_i, dout_i) and return the
// per-element value written to (or, for a broadcast dy, summed into) dx / dy.
// Pass nullptr for dx or dy to skip that gradient.
//
// If x and y have identical dims, elements are processed independently with
// platform::ForRange. Otherwise y is broadcast against x: x is viewed as
// (pre, n, post) (see get_mid_dims), and dy is summed over the pre and post
// axes.
//   axis: index of x's dimension where y's dims start; -1 aligns y with the
//         trailing dims of x.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
                         const framework::Tensor& out,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  if (x.dims() == y.dims()) {
    T* dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
    T* dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
    size_t N = static_cast<size_t>(framework::product(x.dims()));
    platform::ForRange<DeviceContext> for_range(
        ctx.template device_context<DeviceContext>(), N);
    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
        dx_data, dy_data});
    return;
  }

  // Broadcast case: y's dims match a contiguous span of x's dims.
  auto x_dim = x.dims();
  auto y_dim = y.dims();

  axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
  trim_trailing_singular_dims(&y_dim);
  // y made only of size-1 dims trims to rank 0; place it past x's last dim.
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  // Obtain each gradient buffer once, after the dims have been validated.
  T* dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());

  // NOTE(review): when __NVCC__ is not defined, a GPU place falls through
  // without computing anything; presumably GPU instantiations only come from
  // .cu translation units. Confirm.
  if (post == 1) {
    // x viewed as an (h, w) = (pre, n) matrix; y is broadcast along the rows.
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
          dx_data, dy_data);
#endif
    } else {
      ElemwiseGradBroadcast1CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                                dout.data<T>(), h, w, dx_op, dy_op, dx_data,
                                dy_data);
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx_data, dy_data);
#endif
    } else {
      ElemwiseGradBroadcast2CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                                dout.data<T>(), pre, n, post, dx_op, dy_op,
                                dx_data, dy_data);
    }
  }
}
Y
Yu Yang 已提交
558

Q
QI JUN 已提交
559
template <typename DeviceContext, typename T, typename functor,
F
fengjiayi 已提交
560
          typename broadcastfunctor, typename broadcast2functor>
C
chengduoZH 已提交
561 562 563 564 565 566
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
                            const framework::Tensor* x,
                            const framework::Tensor* y,
                            const framework::Tensor* out,
                            const framework::Tensor* dout, int axis,
                            framework::Tensor* dx, framework::Tensor* dy) {
Q
QI JUN 已提交
567
  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
586
  trim_trailing_singular_dims(&y_dims);
587
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
588 589

  int pre, n, post;
590
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
591 592 593 594 595 596 597 598 599 600 601

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
F
fengjiayi 已提交
602

603 604
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
C
chengduoZH 已提交
605 606
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                          const framework::Tensor* x,
C
chengduoZH 已提交
607
                          const framework::Tensor* y, int axis, Functor func,
C
chengduoZH 已提交
608
                          framework::Tensor* z) {
609
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
C
chengduoZH 已提交
610
      x, y, z, ctx.template device_context<DeviceContext>(), func);
F
fengjiayi 已提交
611 612 613 614 615 616 617 618 619 620 621 622 623 624

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
625
  trim_trailing_singular_dims(&y_dims);
626
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
F
fengjiayi 已提交
627 628

  int pre, n, post;
629
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
F
fengjiayi 已提交
630 631 632 633 634 635 636 637 638
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}

639 640
}  // namespace operators
}  // namespace paddle