elementwise_op_function.h 21.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#pragma once
16
#include <algorithm>
Y
Yi Wang 已提交
17 18 19 20
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
21

C
chengduoZH 已提交
22
#ifdef __NVCC__
23
#include <cuda.h>
C
chengduoZH 已提交
24
#include <thrust/iterator/iterator_adaptor.h>
Y
Yu Yang 已提交
25
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
C
chengduoZH 已提交
26 27
#endif

Y
Yi Wang 已提交
28
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
29
#include "paddle/fluid/platform/for_range.h"
30 31 32 33 34 35 36 37 38 39

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, they will be reshaped.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
// Splits x_dims into three groups around the span that y_dims covers when it
// is aligned with x_dims starting at `axis`:
//   *pre  : product of x_dims[0 .. axis)
//   *n    : product of all entries of y_dims
//   *post : product of x_dims[axis + y_dims.size() .. x_dims.size())
// Every y_dims[i] must equal x_dims[i + axis]; otherwise PADDLE_ENFORCE_EQ
// reports "Broadcast dimension mismatch.".
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int* pre, int* n, int* post) {
  int& leading = *pre;
  int& middle = *n;
  int& trailing = *post;
  leading = 1;
  middle = 1;
  trailing = 1;

  // Dimensions of x that precede the overlap with y.
  int lead_idx = 0;
  while (lead_idx < axis) {
    leading *= x_dims[lead_idx];
    ++lead_idx;
  }

  // Overlapping dimensions: they must agree one by one.
  const int y_rank = y_dims.size();
  for (int i = 0; i < y_rank; ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    middle *= y_dims[i];
  }

  // Dimensions of x that follow the overlap with y.
  const int x_rank = x_dims.size();
  for (int tail_idx = axis + y_rank; tail_idx < x_rank; ++tail_idx) {
    trailing *= x_dims[tail_idx];
  }
}

66
// Drops every trailing dimension of size 1 from *dims in place. When all
// dimensions are 1 the result is built from an empty vector. *dims is left
// untouched when its last dimension is not 1.
inline void trim_trailing_singular_dims(framework::DDim* dims) {
  const auto full_rank = dims->size();

  // Walk backwards until a dimension different from 1 is found.
  auto kept_rank = full_rank;
  while (kept_rank != 0 && (*dims)[kept_rank - 1] == 1) {
    --kept_rank;
  }

  if (kept_rank == full_rank) {
    return;  // nothing to trim
  }

  auto shape = framework::vectorize(*dims);
  shape.resize(kept_rank);
  *dims = framework::make_ddim(shape);
}

Q
QI JUN 已提交
79
// Primary templates of the broadcasting iterators. Only the per-device
// specializations below are defined.

// Iterator over y that cycles through its first n elements.
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;

// Iterator over y that repeats each of its first n elements `post` times
// before moving on, then cycles.
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
C
chengduoZH 已提交
83 84

template <typename T>
Q
QI JUN 已提交
85
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
86
 public:
C
chengduoZH 已提交
87 88
  RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

Q
QI JUN 已提交
89
  RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
90
    ++i_;
C
chengduoZH 已提交
91 92 93
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
C
chengduoZH 已提交
94 95 96
    return *this;
  }

Q
QI JUN 已提交
97 98
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
99
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
100 101
  }

Q
QI JUN 已提交
102 103
  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
104
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
105 106 107 108
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
109
 private:
C
chengduoZH 已提交
110 111
  const T* ptr_;
  int i_;
C
chengduoZH 已提交
112
  int64_t n_;
C
chengduoZH 已提交
113 114 115
};

template <typename T>
Q
QI JUN 已提交
116
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
117
 public:
C
chengduoZH 已提交
118 119 120
  MidWiseTransformIterator(const T* ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

Q
QI JUN 已提交
121
  MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
122
    ++j_;
C
chengduoZH 已提交
123 124
    if (UNLIKELY(j_ == post_)) {
      ++i_;
C
refine  
chengduoZH 已提交
125
      j_ = 0;
C
chengduoZH 已提交
126 127 128
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
C
chengduoZH 已提交
129
    }
C
chengduoZH 已提交
130 131 132
    return *this;
  }

Q
QI JUN 已提交
133 134
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
135
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
136 137
  }

Q
QI JUN 已提交
138 139
  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
140
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
141 142 143 144
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
145
 private:
C
chengduoZH 已提交
146
  const T* ptr_;
C
refine  
chengduoZH 已提交
147
  int64_t i_;
C
chengduoZH 已提交
148 149
  int64_t j_;
  int64_t n_;
C
refine  
chengduoZH 已提交
150
  int64_t post_;
C
chengduoZH 已提交
151 152
};

C
chengduoZH 已提交
153 154
#ifdef __NVCC__
template <typename T>
Q
QI JUN 已提交
155
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
156
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
157
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
158 159
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
160
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
161
      super_t;
C
chengduoZH 已提交
162
  HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
163
      : super_t(x), begin_(x), n_(n) {}
C
chengduoZH 已提交
164 165 166 167 168
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
169
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
170 171 172 173 174
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
Q
QI JUN 已提交
175
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
176
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
177
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
178 179
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
180
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
181
      super_t;
C
chengduoZH 已提交
182
  HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
183
      : super_t(x), begin_(x), n_(n), post_(post) {}
C
chengduoZH 已提交
184 185 186 187 188 189
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
190
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
191 192 193 194 195
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

196 197
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
198 199
class TransformFunctor {
 public:
C
chengduoZH 已提交
200
  TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
Q
QI JUN 已提交
201
                   framework::Tensor* z, const DeviceContext& ctx, Functor func)
C
chengduoZH 已提交
202 203
      : x_(x->data<T>()),
        y_(y->data<T>()),
204
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
205 206 207 208 209
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
Q
QI JUN 已提交
210
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
211
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
212 213 214
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
215 216 217
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
C
chengduoZH 已提交
218 219 220
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
221 222 223
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
C
chengduoZH 已提交
224 225
  }

C
chengduoZH 已提交
226
 private:
C
chengduoZH 已提交
227 228
  const T* x_;
  const T* y_;
229
  OutType* z_;
C
chengduoZH 已提交
230
  int64_t nx_;
Q
QI JUN 已提交
231
  const DeviceContext& ctx_;
C
chengduoZH 已提交
232 233 234
  Functor func_;
};

235 236
// EIGEN_FUNCTOR(name, eigen_op) defines `struct Eigen<name>Functor` with three
// member templates that evaluate eigen_op on flattened Eigen views of x and y
// and assign the result to the flattened view of z on the Eigen device of the
// DeviceContext:
//   Run           - x and y are combined element by element.
//   RunBroadCast  - y is viewed as (1, n) and broadcast to (pre, n).
//   RunBroadCast2 - y is viewed as (1, n, 1) and broadcast to (pre, n, post).
// In both broadcast variants the broadcast y is reshaped to x's flat size.
#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
    template <typename DeviceContext, typename T>                              \
    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
                    framework::Tensor* z,                                      \
                    const framework::ExecutionContext& ctx) {                  \
      auto lhs = framework::EigenVector<T>::Flatten(*x);                       \
      auto rhs = framework::EigenVector<T>::Flatten(*y);                       \
      auto result = framework::EigenVector<T>::Flatten(*z);                    \
      auto& dev =                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device();        \
      result.device(dev) = eigen_op(lhs, rhs);                                 \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast(const framework::Tensor* x,                       \
                             const framework::Tensor* y, framework::Tensor* z, \
                             const framework::ExecutionContext& ctx, int pre,  \
                             int n) {                                          \
      auto lhs = framework::EigenVector<T>::Flatten(*x);                       \
      auto rhs = framework::EigenVector<T>::Flatten(*y);                       \
      auto result = framework::EigenVector<T>::Flatten(*z);                    \
      auto rhs_bcast = rhs.reshape(Eigen::DSizes<int, 2>(1, n))                \
                           .broadcast(Eigen::DSizes<int, 2>(pre, 1))           \
                           .reshape(Eigen::DSizes<int, 1>(lhs.size()));        \
      auto& dev =                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device();        \
      result.device(dev) = eigen_op(lhs, rhs_bcast);                           \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast2(const framework::Tensor* x,                      \
                              const framework::Tensor* y,                      \
                              framework::Tensor* z,                            \
                              const framework::ExecutionContext& ctx, int pre, \
                              int n, int post) {                               \
      auto lhs = framework::EigenVector<T>::Flatten(*x);                       \
      auto rhs = framework::EigenVector<T>::Flatten(*y);                       \
      auto result = framework::EigenVector<T>::Flatten(*z);                    \
      auto rhs_bcast = rhs.reshape(Eigen::DSizes<int, 3>(1, n, 1))             \
                           .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))     \
                           .reshape(Eigen::DSizes<int, 1>(lhs.size()));        \
      auto& dev =                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device();        \
      result.device(dev) = eigen_op(lhs, rhs_bcast);                           \
    }                                                                          \
  }

// Binary expression macros passed to EIGEN_FUNCTOR as `eigen_op`, each
// followed by the instantiation that defines the matching Eigen<Name>Functor
// struct. Operands are parenthesized so that arbitrary expressions can be
// passed safely.

// EigenAddFunctor: x + y
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);

// EigenSubFunctor: x - y
#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

// EigenMulFunctor: x * y
#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

// EigenDivFunctor: x / y
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);

Y
Yu Yang 已提交
293 294 295 296 297 298 299 300 301 302 303 304
// Per-element gradient functor for the case where no broadcasting is needed.
// For index i it writes
//   dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i])   when dx_ is non-null
//   dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i])   when dy_ is non-null
//
// The struct is brace-initialized by its users, so the order of the data
// members below is part of its interface.
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  const T* x_;
  const T* y_;
  const T* out_;
  const T* dout_;
  DX_OP dx_op_;
  DY_OP dy_op_;
  T* dx_;
  T* dy_;

  HOSTDEVICE void operator()(size_t i) {
    const bool want_dx = (dx_ != nullptr);
    const bool want_dy = (dy_ != nullptr);
    if (want_dx) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (want_dy) {
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
  }
};

// CPU gradient for the row-wise broadcast case: x, out and dout are viewed as
// (h, w) row-major matrices and y as a vector of length w.
//
//   dx[i * w + j] = dx_op(x, y[j], out, dout) at element (i, j)
//   dy[j]         = sum over i of dy_op(x, y[j], out, dout) at element (i, j)
//
// dx and/or dy may be nullptr, in which case the corresponding gradient is
// skipped. dy[j] is assigned on the first row and accumulated afterwards, so
// it does not need to be zeroed by the caller (it is left untouched if h == 0).
//
// The flat offset is computed in 64 bits: h and w each fit in an int, but
// their product may not, and signed int overflow is undefined behaviour.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T* dx, T* dy) {
  for (int i = 0; i < h; ++i) {
    const int64_t row_base = static_cast<int64_t>(i) * w;
    for (int j = 0; j < w; ++j) {
      const int64_t x_offset = row_base + j;
      if (dx != nullptr) {
        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }
      if (dy != nullptr) {
        T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        if (i == 0) {
          dy[j] = tmp;  // first contribution initializes the accumulator
        } else {
          dy[j] += tmp;
        }
      }
    }
  }
}
#ifdef __NVCC__
337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385

// __shfl_down has been deprecated as of CUDA 9.0.
// Compatibility layer so that the code below can always call
// __shfl_down_sync and CREATE_SHFL_MASK regardless of the CUDA version.
#if CUDA_VERSION < 9000
// Before CUDA 9.0: provide __shfl_down_sync as a thin wrapper that ignores
// the mask argument and forwards to __shfl_down.
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
  return __shfl_down(val, delta);
}
// The mask is unused by the wrapper above, so it is simply set to 0.
// Note: the expansion already ends with a semicolon.
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
// CUDA 9.0 and later: the mask is the __ballot_sync result of `predicate`,
// evaluated with all 32 lanes named as participants.
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Block-level sum of `val` across the threads of a block, done in two stages:
//   1. every warp reduces its values with shuffle-down; the warp's partial sum
//      ends up in its lane 0, which stores it to shm[warp index];
//   2. the first warp loads the per-warp partial sums from shm and reduces
//      them with shuffle-down again.
// The callers in this file only read the returned value in thread 0.
//
// tid: index of the calling thread inside the block.
// len: number of threads considered active; it is only used as the predicate
//      (tid < len) when building the shuffle mask.
//
// Must be called by all threads of the block because it uses __syncthreads().
// shm has 32 slots, so at most 32 warps per block are supported.
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
  // TODO(zcd): The warp size should be taken from the
  // parameters of the GPU but not specified as 32 simply.
  // To make the reduceSum more efficiently,
  // I use Warp-Level Parallelism and assume the Warp size
  // is 32 which may be different for different GPU,
  // but most card's warp size is 32.
  __shared__ T shm[32];
  const int warpSize = 32;
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < len);

  // Stage 1: reduce inside each warp.
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(mask, val, offset);

  // Clear all slots so that slots of non-existent warps contribute 0.
  if (tid < warpSize) shm[tid] = 0;

  // The zeroing above must be visible before any warp stores its partial sum.
  __syncthreads();

  if (tid % warpSize == 0) {
    shm[tid / warpSize] = val;
  }

  // All per-warp partial sums must be written before the first warp reads
  // them below; without this barrier the read races with the writes of the
  // other warps.
  __syncthreads();

  CREATE_SHFL_MASK(mask, tid < warpSize);

  // Stage 2: the first warp reduces the per-warp partial sums.
  if (tid < warpSize) {
    val = shm[tid];
    for (int offset = warpSize / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(mask, val, offset);
  }

  return val;
}

Y
Yu Yang 已提交
386 387 388 389 390 391 392
// CUDA kernel for the row-wise broadcast gradient; x, out and dout are viewed
// as (h, w) row-major matrices and y as a vector of length w.
// Block blockIdx.x handles column blockIdx.x. Each thread starts at row
// threadIdx.x and advances by ELEMWISE_MAX_BLOCK_DIM rows per iteration,
// writing dx for every element it visits and accumulating its share of dy.
// The per-thread dy contributions are combined with reduceSum and thread 0
// stores the result to dy[column]. dx and/or dy may be nullptr.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int col = blockIdx.x;
  const int tid = threadIdx.x;
  T partial = 0;

  int row = tid;
  do {
    const int idx = row * w + col;
    if (dx != nullptr) {
      dx[idx] = dx_op(x[idx], y[col], out[idx], dout[idx]);
    }
    if (dy != nullptr) {
      partial += dy_op(x[idx], y[col], out[idx], dout[idx]);
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dy == nullptr) {
    return;
  }

  // Number of threads taking part in the reduction, capped at the block size.
  const int active =
      h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
  partial = reduceSum(partial, tid, active);
  if (tid == 0) {
    dy[col] = partial;
  }
}

// Launches ElemwiseGradBroadcast1CUDAKernel on `stream` with one block per
// column (w blocks) and min(ELEMWISE_MAX_BLOCK_DIM, h) threads per block.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T* dx, T* dy) {
  const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  const int num_blocks = w;
  ElemwiseGradBroadcast1CUDAKernel<<<num_blocks, threads_per_block, 0,
                                     stream>>>(x, y, out, dout, h, w, dx_op,
                                               dy_op, dx, dy);
}

#endif

// CPU gradient for the mid-wise broadcast case: x, out and dout are viewed as
// (pre, n, post) row-major tensors and y as a vector of length n.
//
//   dx[(i, j, k)] = dx_op(x, y[j], out, dout) at element (i, j, k)
//   dy[j]         = sum over i and k of dy_op(x, y[j], out, dout) at (i, j, k)
//
// dx and/or dy may be nullptr, in which case the corresponding gradient is
// skipped. dy[j] is assigned on its first contribution (i == 0, k == 0) and
// accumulated afterwards, so the caller does not need to zero it.
//
// The flat offset is computed in 64 bits: pre, n and post each fit in an int,
// but their product may not, and signed int overflow is undefined behaviour.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int64_t slab = static_cast<int64_t>(n) * post;  // elements per i
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      const int64_t run_base = i * slab + static_cast<int64_t>(j) * post;
      for (int k = 0; k < post; ++k) {
        const int64_t x_offset = run_base + k;
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          if (i == 0 && k == 0) {
            dy[j] = tmp;  // first contribution initializes the accumulator
          } else {
            dy[j] += tmp;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
// CUDA kernel for the mid-wise broadcast gradient; x, out and dout are viewed
// as (pre, n, post) row-major tensors and y as a vector of length n.
// Block blockIdx.x handles index j = blockIdx.x of the middle dimension. Each
// thread walks the flattened (pre, post) plane starting at threadIdx.x and
// advancing by ELEMWISE_MAX_BLOCK_DIM, writing dx for every element it visits
// and accumulating its share of dy. The per-thread dy contributions are
// combined with reduceSum and thread 0 stores the result to dy[j].
// dx and/or dy may be nullptr.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int tid = threadIdx.x;
  const int j = blockIdx.x;

  T partial = 0;

  for (int flat = tid;; flat += ELEMWISE_MAX_BLOCK_DIM) {
    // Decompose the flat position into (i, k) of the (pre, post) plane.
    const int i = flat / post;
    const int k = flat % post;
    if (i >= pre) {
      break;
    }

    const int idx = i * n * post + j * post + k;

    if (dx != nullptr) {
      dx[idx] = dx_op(x[idx], y[j], out[idx], dout[idx]);
    }

    if (dy != nullptr) {
      partial += dy_op(x[idx], y[j], out[idx], dout[idx]);
    }
  }

  if (dy == nullptr) {
    return;
  }

  // Number of threads taking part in the reduction, capped at the block size.
  const int plane = pre * post;
  const int active =
      plane > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : plane;
  partial = reduceSum(partial, tid, active);
  if (tid == 0) {
    dy[j] = partial;
  }
}

// Launches ElemwiseGradBroadcast2CUDAKernel on `stream` with one block per
// entry of the middle dimension (n blocks) and
// min(ELEMWISE_MAX_BLOCK_DIM, pre * post) threads per block.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T* dx, T* dy) {
  const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  const int num_blocks = n;
  ElemwiseGradBroadcast2CUDAKernel<<<num_blocks, threads_per_block, 0,
                                     stream>>>(x, y, out, dout, pre, n, post,
                                               dx_op, dy_op, dx, dy);
}

#endif

// Computes the gradients of an element-wise binary op using the per-element
// functors dx_op / dy_op, each called as op(x, y, out, dout).
//
// * When x and y have identical dims, ElemwiseGradNoBroadcast is run over all
//   elements through platform::ForRange.
// * Otherwise y is broadcast against x: `axis` is resolved (-1 means
//   rank(x) - rank(y)), trailing size-1 dims of y are trimmed, the shapes are
//   folded into (pre, n, post) by get_mid_dims, and the Broadcast1 (post == 1)
//   or Broadcast2 helper is called for the CPU or GPU place of `ctx`.
//
// dx and/or dy may be nullptr; a non-null output is allocated with
// mutable_data<T> on the place of `ctx`. The GPU helpers are only compiled
// under __NVCC__.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
                         const framework::Tensor& out,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  if (x.dims() == y.dims()) {
    // Same shape: plain element-by-element gradient.
    const size_t numel = static_cast<size_t>(framework::product(x.dims()));
    platform::ForRange<DeviceContext> for_range(
        ctx.template device_context<DeviceContext>(), numel);
    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
        dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
        dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
    return;
  }

  // Shapes differ: y has to be broadcast against x.
  auto x_dim = x.dims();
  auto y_dim = y.dims();

  if (axis == -1) {
    axis = x_dim.size() - y_dim.size();
  }
  trim_trailing_singular_dims(&y_dim);
  if (y_dim.size() == 0) {
    // Every dim of y was 1: all dims of x fold into `pre`.
    axis = x_dim.size();
  }

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());

  if (post == 1) {
    // (pre, n) layout: treat x as an h-by-w matrix and y as a row.
    const int h = pre;
    const int w = n;
    if (on_gpu) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w,
          dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
    return;
  }

  // (pre, n, post) layout.
  if (on_gpu) {
#ifdef __NVCC__
    ElemwiseGradBroadcast2CUDA(
        ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
        y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
        dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
        dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
  } else {
    ElemwiseGradBroadcast2CPU(
        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
        dx_op, dy_op,
        dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
        dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
  }
}
Y
Yu Yang 已提交
568

Q
QI JUN 已提交
569
template <typename DeviceContext, typename T, typename functor,
F
fengjiayi 已提交
570
          typename broadcastfunctor, typename broadcast2functor>
C
chengduoZH 已提交
571 572 573 574 575 576
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
                            const framework::Tensor* x,
                            const framework::Tensor* y,
                            const framework::Tensor* out,
                            const framework::Tensor* dout, int axis,
                            framework::Tensor* dx, framework::Tensor* dy) {
Q
QI JUN 已提交
577
  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
596
  trim_trailing_singular_dims(&y_dims);
597
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
598 599

  int pre, n, post;
600
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
601 602 603 604 605 606 607 608 609 610 611

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
F
fengjiayi 已提交
612

613 614
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
C
chengduoZH 已提交
615 616
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                          const framework::Tensor* x,
C
chengduoZH 已提交
617
                          const framework::Tensor* y, int axis, Functor func,
C
chengduoZH 已提交
618
                          framework::Tensor* z) {
619
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
C
chengduoZH 已提交
620
      x, y, z, ctx.template device_context<DeviceContext>(), func);
F
fengjiayi 已提交
621 622 623 624 625 626 627 628 629 630 631 632 633 634

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
635
  trim_trailing_singular_dims(&y_dims);
636
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
F
fengjiayi 已提交
637 638

  int pre, n, post;
639
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
F
fengjiayi 已提交
640 641 642 643 644 645 646 647 648
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}

649 650
}  // namespace operators
}  // namespace paddle