/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <glog/logging.h>
#include <algorithm>
#include <iterator>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"

#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, they will be reshaped.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int* pre, int* n, int* post) {
  *pre = 1;
  *n = 1;
  *post = 1;
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }

  for (int i = 0; i < y_dims.size(); ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    (*n) *= y_dims[i];
  }

  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
}
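
// Worked example (illustrative sketch only, matching case 1 in the comment
// above; not part of the public interface):
//
//   int pre, n, post;
//   get_mid_dims(framework::make_ddim({2, 3, 4, 5}),
//                framework::make_ddim({3, 4}), 1, &pre, &n, &post);
//   // pre == 2, n == 12, post == 5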

inline framework::DDim trim_trailing_singular_dims(
    const framework::DDim& dims) {
  // Remove trailing dimensions of size 1 for y
  auto actual_dims_size = dims.size();
  for (; actual_dims_size != 0; --actual_dims_size) {
    if (dims[actual_dims_size - 1] != 1) break;
  }

  std::vector<int> trim_dims;
  trim_dims.resize(actual_dims_size);
  for (int i = 0; i < actual_dims_size; ++i) {
    trim_dims[i] = dims[i];
  }
  if (trim_dims.size() == 0) {
    return framework::DDim(framework::make_dim());
  }
  framework::DDim actual_dims = framework::make_ddim(trim_dims);
  return actual_dims;
}
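
// Example (illustrative sketch only): only trailing size-1 dimensions are
// removed; leading and interior ones are kept.
//
//   trim_trailing_singular_dims(framework::make_ddim({3, 4, 1, 1}));  // (3, 4)
//   trim_trailing_singular_dims(framework::make_ddim({1, 4, 1}));     // (1, 4)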

template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;

// NOTE(dzhwinter): ptrdiff_t in std::iterator is deprecated in C++17
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
                           T*, T&> {
 public:
  RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

  RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
    ++i_;
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
    return *this;
  }

  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  const T& operator*() { return ptr_[i_]; }

 private:
  const T* ptr_;
  int i_;
  int64_t n_;
};
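
// Access pattern sketch (illustrative only): with y = {a, b, c} and n == 3,
// dereferencing while incrementing yields a, b, c, a, b, c, ..., i.e. position
// k of the flattened x is paired with y[k % n].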

template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
                           T*, T&> {
 public:
  MidWiseTransformIterator(const T* ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

  MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
    ++j_;
    if (UNLIKELY(j_ == post_)) {
      ++i_;
      j_ = 0;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  const T& operator*() { return ptr_[i_]; }

 private:
  const T* ptr_;
  int64_t i_;
  int64_t j_;
  int64_t n_;
  int64_t post_;
};
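
// Access pattern sketch (illustrative only): with y = {a, b}, n == 2 and
// post == 3, the sequence is a, a, a, b, b, b, a, a, a, ..., i.e. position k
// of the flattened x is paired with y[(k / post) % n].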

#ifdef __NVCC__
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
 public:
  typedef thrust::iterator_adaptor<
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
      super_t;
  HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
      : super_t(x), begin_(x), n_(n) {}
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T* begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
 public:
  typedef thrust::iterator_adaptor<
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
      super_t;
  HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
      : super_t(x), begin_(x), n_(n), post_(post) {}
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T* begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
class TransformFunctor {
 public:
  TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z, const DeviceContext& ctx, Functor func)
      : x_(x->data<T>()),
        y_(y->data<T>()),
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
  }

  inline void RunRowWise(int n, int pre) const {
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
  }

  inline void RunMidWise(int n, int pre, int post) const {
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
  }

 private:
  const T* x_;
  const T* y_;
  OutType* z_;
  int64_t nx_;
  const DeviceContext& ctx_;
  Functor func_;
};
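
// Usage sketch (illustrative only; MulFunctor below is a hypothetical binary
// functor, not defined in this header):
//
//   struct MulFunctor {
//     template <typename T>
//     HOSTDEVICE T operator()(T a, T b) const { return a * b; }
//   };
//
//   TransformFunctor<MulFunctor, float, platform::CPUDeviceContext> f(
//       &x, &y, &z, dev_ctx, MulFunctor());
//   f.RunRowWise(n, pre);  // z[i * n + j] = x[i * n + j] * y[j]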

#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
    template <typename DeviceContext, typename T>                              \
    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
                    framework::Tensor* z,                                      \
                    const framework::ExecutionContext& ctx) {                  \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_e);                                                  \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast(const framework::Tensor* x,                       \
                             const framework::Tensor* y, framework::Tensor* z, \
                             const framework::ExecutionContext& ctx, int pre,  \
                             int n) {                                          \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast2(const framework::Tensor* x,                      \
                              const framework::Tensor* y,                      \
                              framework::Tensor* z,                            \
                              const framework::ExecutionContext& ctx, int pre, \
                              int n, int post) {                               \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
    }                                                                          \
  }

#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);

#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
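
// Expansion sketch (illustrative only): EIGEN_FUNCTOR(Add, EIGEN_ADD) defines
// a struct EigenAddFunctor whose Run/RunBroadCast/RunBroadCast2 members
// evaluate z = x + broadcast(y) on the Eigen device of the ExecutionContext,
// e.g.
//
//   EigenAddFunctor functor;
//   functor.Run<platform::CPUDeviceContext, float>(&x, &y, &z, ctx);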

template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  const T* x_;
  const T* y_;
  const T* out_;
  const T* dout_;

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_ != nullptr) {
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
  T* dx_;
  T* dy_;
};

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T* dx, T* dy) {
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      int x_offset = i * w + j;
      if (dx != nullptr) {
        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }
      if (dy != nullptr) {
        T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        if (i == 0) {
          dy[j] = tmp;
        } else {
          dy[j] += tmp;
        }
      }
    }
  }
}
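
// Usage sketch (illustrative only; IdentityGrad below is hypothetical):
//
//   struct IdentityGrad {
//     HOSTDEVICE float operator()(float x, float y, float out,
//                                 float dout) const {
//       return dout;
//     }
//   };
//
//   // With h == 2, w == 3 this writes dx[i * 3 + j] = dout[i * 3 + j] and
//   // dy[j] = dout[j] + dout[3 + j], i.e. the column-wise reduction needed
//   // when y was broadcast over the h rows.
//   ElemwiseGradBroadcast1CPU(x, y, out, dout, 2, 3, IdentityGrad(),
//                             IdentityGrad(), dx, dy);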

#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  int j = blockIdx.x;
  int i = threadIdx.x;
  int tid = threadIdx.x;
  T val = 0;

  do {
    int x_offset = i * w + j;
    if (dx) {
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
    if (dy) {
      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
    i += ELEMWISE_MAX_BLOCK_DIM;
  } while (i < h);

  if (dy) {
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    val = paddle::platform::reduceSum(val, tid, h);
    if (threadIdx.x == 0) {
      dy[j] = val;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T* dx, T* dy) {
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  int grid_size = w;
  ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
}

#endif

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        int x_offset = i * n * post + j * post + k;
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          if (i == 0 && k == 0) {
            dy[j] = tmp;
          } else {
            dy[j] += tmp;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  int tid = threadIdx.x;
  int j = blockIdx.x;

  T val = 0;
  int ttid = tid;

  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int x_offset = i * n * post + j * post + k;

    if (dx != nullptr) {
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }

    if (dy != nullptr) {
      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }

    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

  if (dy) {
    int h = pre * post;
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    val = paddle::platform::reduceSum(val, tid, h);
    if (threadIdx.x == 0) {
      dy[j] = val;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T* dx, T* dy) {
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int grid_size = n;
  ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}

#endif

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
    const framework::DDim& y_dim, const framework::Tensor& x,
    const framework::Tensor& y, const framework::Tensor& out,
    const framework::Tensor& dout, int axis, framework::Tensor* dx,
    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
  size_t N = static_cast<size_t>(framework::product(x_dim));
#if !defined(_WIN32)
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
#else
  platform::ForRange<DeviceContext> for_range(
      ctx.device_context<DeviceContext>(), N);
#endif  // !_WIN32
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
    const framework::DDim& y_dim_untrimed, const framework::Tensor& x,
    const framework::Tensor& y, const framework::Tensor& out,
    const framework::Tensor& dout, int axis, framework::Tensor* dx,
    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
  if (post == 1) {
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast2CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
          dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  }
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
                         const framework::Tensor& out,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  const framework::DDim& x_dim = x.dims();
  const framework::DDim& y_dim = y.dims();
  if (x.dims() == y.dims()) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {  // Y has a smaller shape and is broadcast over X
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
                                 const framework::Tensor& x,
                                 const framework::Tensor& y,
                                 const framework::Tensor& out,
                                 const framework::Tensor& dout, int axis,
                                 framework::Tensor* dx, framework::Tensor* dy,
                                 DX_OP dx_op, DY_OP dy_op) {
  if (dy == nullptr) {
    const framework::DDim& dx_dims = dout.dims();
    auto dy_dims = dx_dims;
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    if (dout.dims() == dy->dims()) {
      const framework::DDim& dx_dims = dout.dims();
      const framework::DDim& dy_dims = dy->dims();
      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    } else {  // dY has a smaller (broadcast) shape than dOut
      auto dx_dims = dout.dims();
      const framework::DDim& dy_dims = dy->dims();
      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    }
  }
}

// Deprecated
template <typename DeviceContext, typename T, typename functor,
          typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
                            const framework::Tensor* x,
                            const framework::Tensor* y,
                            const framework::Tensor* out,
                            const framework::Tensor* dout, int axis,
                            framework::Tensor* dx, framework::Tensor* dy) {
  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  y_dims = trim_trailing_singular_dims(y_dims);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}

template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                          const framework::Tensor* x,
                          const framework::Tensor* y, int axis, Functor func,
                          framework::Tensor* z) {
#if !defined(_WIN32)
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
      x, y, z, ctx.template device_context<DeviceContext>(), func);
#else
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
      x, y, z, ctx.device_context<DeviceContext>(), func);
#endif  // !_WIN32
  auto x_dims = x->dims();
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims_untrimed) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}
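
// Usage sketch (illustrative only; AddFunctor here is a stand-in for the
// functors defined by the concrete elementwise ops, not part of this header):
//
//   template <typename T>
//   struct AddFunctor {
//     inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
//   };
//
//   ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
//       ctx, x, y, axis, AddFunctor<T>(), z);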

}  // namespace operators
}  // namespace paddle