/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"

#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, they will be reshaped.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=2 (or axis=-1)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int* pre, int* n, int* post) {
  *pre = 1;
  *n = 1;
  *post = 1;
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }

  for (int i = 0; i < y_dims.size(); ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    (*n) *= y_dims[i];
  }

  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
}
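
// A sketch of the first example above: for shape(X) = (2, 3, 4, 5),
// shape(Y) = (3, 4) and axis = 1,
//
//   int pre, n, post;
//   get_mid_dims(framework::make_ddim({2, 3, 4, 5}),
//                framework::make_ddim({3, 4}), 1, &pre, &n, &post);
//   // pre == 2, n == 12, post == 5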

inline framework::DDim trim_trailing_singular_dims(
    const framework::DDim& dims) {
  // Remove trailing dimensions of size 1 for y
  auto actual_dims_size = dims.size();
  for (; actual_dims_size != 0; --actual_dims_size) {
    if (dims[actual_dims_size - 1] != 1) break;
  }

  std::vector<int> trim_dims;
  trim_dims.resize(actual_dims_size);
  for (int i = 0; i < actual_dims_size; ++i) {
    trim_dims[i] = dims[i];
  }
  if (trim_dims.size() == 0) {
    return framework::DDim(framework::make_dim());
  }
  framework::DDim actual_dims = framework::make_ddim(trim_dims);
  return actual_dims;
}
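
// Behaviour sketch: trailing size-1 dimensions are dropped, so (3, 1, 1)
// trims to (3) and (1, 1) trims to a rank-0 DDim, while (2, 1, 3) is left
// unchanged because only *trailing* singular dimensions are removed.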

template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;

template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
 public:
  RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

  RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
    ++i_;
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
    return *this;
  }

  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // const-qualified so it can be called on the const rhs in the comparison
  // operators above.
  const T& operator*() const { return ptr_[i_]; }

 private:
  const T* ptr_;
  int i_;
  int64_t n_;
};

template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
 public:
  MidWiseTransformIterator(const T* ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

  MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
    ++j_;
    if (UNLIKELY(j_ == post_)) {
      ++i_;
      j_ = 0;
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
    }
    return *this;
  }

  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

  // const-qualified so it can be called on the const rhs in the comparison
  // operators above.
  const T& operator*() const { return ptr_[i_]; }

 private:
  const T* ptr_;
  int64_t i_;
  int64_t j_;
  int64_t n_;
  int64_t post_;
};
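
// Iteration sketch for the two CPU iterators: RowwiseTransformIterator yields
// y[0], ..., y[n-1], y[0], ... so one length-n vector is reused across all
// rows of x, while MidWiseTransformIterator repeats each y[i] `post` times
// before advancing (and wraps after n values), matching the (pre, n, post)
// layout produced by get_mid_dims.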

#ifdef __NVCC__
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
 public:
  typedef thrust::iterator_adaptor<
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
      super_t;
  HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
      : super_t(x), begin_(x), n_(n) {}
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T* begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
    : public thrust::iterator_adaptor<
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
 public:
  typedef thrust::iterator_adaptor<
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
      super_t;
  HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
      : super_t(x), begin_(x), n_(n), post_(post) {}
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T* begin_;
  HOSTDEVICE typename super_t::reference dereference() const {
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
class TransformFunctor {
 public:
  TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z, const DeviceContext& ctx, Functor func)
      : x_(x->data<T>()),
        y_(y->data<T>()),
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
  }

  inline void RunRowWise(int n, int pre) const {
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
  }

  inline void RunMidWise(int n, int pre, int post) const {
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
  }

 private:
  const T* x_;
  const T* y_;
  OutType* z_;
  int64_t nx_;
  const DeviceContext& ctx_;
  Functor func_;
};

#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
    template <typename DeviceContext, typename T>                              \
    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
                    framework::Tensor* z,                                      \
                    const framework::ExecutionContext& ctx) {                  \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_e);                                                  \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast(const framework::Tensor* x,                       \
                             const framework::Tensor* y, framework::Tensor* z, \
                             const framework::ExecutionContext& ctx, int pre,  \
                             int n) {                                          \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast2(const framework::Tensor* x,                      \
                              const framework::Tensor* y,                      \
                              framework::Tensor* z,                            \
                              const framework::ExecutionContext& ctx, int pre, \
                              int n, int post) {                               \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
    }                                                                          \
  }

#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);

#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
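
// As a sketch, EIGEN_FUNCTOR(Add, EIGEN_ADD) above expands to a struct
// EigenAddFunctor whose Run/RunBroadCast/RunBroadCast2 members evaluate
// z = x + broadcast(y) on the context's Eigen device; the Sub, Mul and Div
// instantiations differ only in the arithmetic operator applied.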

template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  const T* x_;
  const T* y_;
  const T* out_;
  const T* dout_;

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_ != nullptr) {
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
  T* dx_;
  T* dy_;
};

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T* dx, T* dy) {
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      int x_offset = i * w + j;
      if (dx != nullptr) {
        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }
      if (dy != nullptr) {
        T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        if (i == 0) {
          dy[j] = tmp;
        } else {
          dy[j] += tmp;
        }
      }
    }
  }
}
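
// Layout sketch for the Broadcast1 case: x, out and dout are viewed as an
// (h, w) matrix and y as a length-w row vector, so dx keeps the (h, w) shape
// while dy[j] accumulates (sums) the per-element gradients of column j over
// all h rows.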

#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  int j = blockIdx.x;
  int i = threadIdx.x;
  int tid = threadIdx.x;
  T val = 0;

  do {
    int x_offset = i * w + j;
    if (dx) {
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
    if (dy) {
      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
    i += ELEMWISE_MAX_BLOCK_DIM;
  } while (i < h);

  if (dy) {
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    val = paddle::platform::reduceSum(val, tid, h);
    if (threadIdx.x == 0) {
      dy[j] = val;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T* dx, T* dy) {
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  int grid_size = w;
  ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
}

#endif

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        int x_offset = i * n * post + j * post + k;
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          if (i == 0 && k == 0) {
            dy[j] = tmp;
          } else {
            dy[j] += tmp;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  int tid = threadIdx.x;
  int j = blockIdx.x;

  T val = 0;
  int ttid = tid;

  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int x_offset = i * n * post + j * post + k;

    if (dx != nullptr) {
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }

    if (dy != nullptr) {
      val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }

    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

  if (dy) {
    int h = pre * post;
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    val = paddle::platform::reduceSum(val, tid, h);
    if (threadIdx.x == 0) {
      dy[j] = val;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T* dx, T* dy) {
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  int grid_size = n;
  ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}

#endif
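
// Launch-shape sketch for both CUDA kernels above: the grid has one block per
// y element (grid_size = w, respectively n); each block runs up to
// ELEMWISE_MAX_BLOCK_DIM threads that stride over the remaining x elements,
// and the per-thread partial dy sums are combined with platform::reduceSum.
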
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
    const framework::DDim& y_dim, const framework::Tensor& x,
    const framework::Tensor& y, const framework::Tensor& out,
    const framework::Tensor& dout, int axis, framework::Tensor* dx,
    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
  size_t N = static_cast<size_t>(framework::product(x_dim));
  platform::ForRange<DeviceContext> for_range(
      ctx.template device_context<DeviceContext>(), N);
  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
    const framework::DDim& y_dim_untrimed, const framework::Tensor& x,
    const framework::Tensor& y, const framework::Tensor& out,
    const framework::Tensor& dout, int axis, framework::Tensor* dx,
    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
  if (post == 1) {
    int h = pre;
    int w = n;
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast2CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
          dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  }
}
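
// Dispatch sketch: post == 1 means y lines up with x's trailing dimensions
// (the rowwise case, handled by the Broadcast1 paths with h = pre, w = n);
// otherwise every y[j] covers `post` contiguous elements of x (the midwise
// case, handled by the Broadcast2 paths).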

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
                         const framework::Tensor& out,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  const framework::DDim& x_dim = x.dims();
  const framework::DDim& y_dim = y.dims();
  if (x.dims() == y.dims()) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {  // y is broadcast to x's shape
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// NOTE(dzhwinter): Only used in elementwise_add and elementwise_sub.
// An explicit gradient lets the gradient op drop X, Y, and Out as inputs:
// in elementwise_add and elementwise_sub we pass dout as a stand-in for
// X, Y, and Out so the generic elementwise code above can be reused.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
                                 const framework::Tensor& x,
                                 const framework::Tensor& y,
                                 const framework::Tensor& out,
                                 const framework::Tensor& dout, int axis,
                                 framework::Tensor* dx, framework::Tensor* dy,
                                 DX_OP dx_op, DY_OP dy_op) {
  if (dy == nullptr) {
    const framework::DDim& dx_dims = dout.dims();
    auto dy_dims = dx_dims;
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    if (dout.dims() == dy->dims()) {
      const framework::DDim& dx_dims = dout.dims();
      const framework::DDim& dy_dims = dy->dims();
      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    } else {  // y was broadcast over x, so dy must be reduced
      auto dx_dims = dout.dims();
      const framework::DDim& dy_dims = dy->dims();
      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    }
  }
}

// Deprecated
template <typename DeviceContext, typename T, typename functor,
          typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
                            const framework::Tensor* x,
                            const framework::Tensor* y,
                            const framework::Tensor* out,
                            const framework::Tensor* dout, int axis,
                            framework::Tensor* dx, framework::Tensor* dy) {
  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  y_dims = trim_trailing_singular_dims(y_dims);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                          const framework::Tensor* x,
                          const framework::Tensor* y, int axis, Functor func,
                          framework::Tensor* z) {
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
      x, y, z, ctx.template device_context<DeviceContext>(), func);

  auto x_dims = x->dims();
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                    "Rank of first input must be >= rank of second input.");

  if (x_dims == y_dims_untrimed) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims.size()).");
  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}
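
// A minimal usage sketch. AddFunctor is assumed here for illustration, e.g.
//
//   template <typename T>
//   struct AddFunctor {
//     inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
//   };
//
//   ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
//       ctx, x, y, axis, AddFunctor<T>(), z);
//
// With shape(x) = (2, 3, 4, 5), shape(y) = (3, 4) and axis = 1 this takes the
// RunMidWise path (pre = 2, n = 12, post = 5); the same call with
// shape(y) = (4, 5) and axis = -1 takes the RunRowWise path.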

}  // namespace operators
}  // namespace paddle