/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16
#include <algorithm>
Y
Yi Wang 已提交
17 18 19 20
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
21

C
chengduoZH 已提交
22
#ifdef __NVCC__
23
#include <cuda.h>
C
chengduoZH 已提交
24
#include <thrust/iterator/iterator_adaptor.h>
25
#include "paddle/fluid/platform/cuda_device_function.h"
D
dzhwinter 已提交
26
#include "paddle/fluid/platform/cuda_primitives.h"
Y
Yu Yang 已提交
27
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
C
chengduoZH 已提交
28 29
#endif

Y
Yi Wang 已提交
30
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
31
#include "paddle/fluid/platform/for_range.h"
32 33 34 35 36 37 38 39 40 41

namespace paddle {
namespace operators {

/*
 * Out = X ⊙ Y
 * If Y's shape does not match X's shape, X is viewed as a (pre, n, post)
 * tensor and Y as (1, n, 1), and Y is broadcast to X's shape.
 * get_mid_dims computes pre, n and post.
 * For example:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
 *    pre=2, n=3*4, post=5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5), with axis=-1 (the default,
 *    which aligns Y with the trailing dimensions of X)
 *    pre=2*3, n=4*5, post=1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int* pre, int* n, int* post) {
  // y must fit inside x starting at `axis`; otherwise the loops below would
  // read x_dims past its last dimension.
  PADDLE_ENFORCE(axis >= 0 && axis + y_dims.size() <= x_dims.size(),
                 "Broadcast axis is out of range.");
  *pre = 1;
  *n = 1;
  *post = 1;
  // pre: product of the dimensions of x that come before `axis`.
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }

  // n: product of the dimensions covered by y; each must equal x's.
  for (int i = 0; i < y_dims.size(); ++i) {
    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                      "Broadcast dimension mismatch.");
    (*n) *= y_dims[i];
  }

  // post: product of the dimensions of x that come after y's span.
  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
}

68
inline void trim_trailing_singular_dims(framework::DDim* dims) {
  // Drop the run of size-1 dimensions at the end of *dims, e.g. a y of shape
  // (3, 4, 1, 1) becomes (3, 4). A shape consisting only of 1s ends up rank 0.
  auto kept = dims->size();
  while (kept != 0 && (*dims)[kept - 1] == 1) {
    --kept;
  }
  if (kept == dims->size()) {
    return;  // no trailing 1s, leave *dims untouched
  }
  auto shape = framework::vectorize(*dims);
  shape.resize(kept);
  *dims = framework::make_ddim(shape);
}

Q
QI JUN 已提交
81
// Broadcast iterators over y; only per-device-context specializations of
// these primary templates are defined below.
template <class T, class DeviceContext>
class RowwiseTransformIterator;
template <class T, class DeviceContext>
class MidWiseTransformIterator;
C
chengduoZH 已提交
85 86

template <typename T>
Q
QI JUN 已提交
87
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
88
 public:
C
chengduoZH 已提交
89 90
  RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

Q
QI JUN 已提交
91
  RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
92
    ++i_;
C
chengduoZH 已提交
93 94 95
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
C
chengduoZH 已提交
96 97 98
    return *this;
  }

Q
QI JUN 已提交
99 100
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
101
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
102 103
  }

Q
QI JUN 已提交
104 105
  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
106
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
107 108 109 110
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
111
 private:
C
chengduoZH 已提交
112 113
  const T* ptr_;
  int i_;
C
chengduoZH 已提交
114
  int64_t n_;
C
chengduoZH 已提交
115 116 117
};

template <typename T>
Q
QI JUN 已提交
118
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
119
 public:
C
chengduoZH 已提交
120 121 122
  MidWiseTransformIterator(const T* ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

Q
QI JUN 已提交
123
  MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
124
    ++j_;
C
chengduoZH 已提交
125 126
    if (UNLIKELY(j_ == post_)) {
      ++i_;
C
refine  
chengduoZH 已提交
127
      j_ = 0;
C
chengduoZH 已提交
128 129 130
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
C
chengduoZH 已提交
131
    }
C
chengduoZH 已提交
132 133 134
    return *this;
  }

Q
QI JUN 已提交
135 136
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
137
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
138 139
  }

Q
QI JUN 已提交
140 141
  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
142
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
143 144 145 146
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
147
 private:
C
chengduoZH 已提交
148
  const T* ptr_;
C
refine  
chengduoZH 已提交
149
  int64_t i_;
C
chengduoZH 已提交
150 151
  int64_t j_;
  int64_t n_;
C
refine  
chengduoZH 已提交
152
  int64_t post_;
C
chengduoZH 已提交
153 154
};

C
chengduoZH 已提交
155 156
#ifdef __NVCC__
template <typename T>
Q
QI JUN 已提交
157
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
158
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
159
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
160 161
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
162
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
163
      super_t;
C
chengduoZH 已提交
164
  HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
165
      : super_t(x), begin_(x), n_(n) {}
C
chengduoZH 已提交
166 167 168 169 170
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
171
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
172 173 174 175 176
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
Q
QI JUN 已提交
177
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
178
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
179
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
180 181
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
182
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
183
      super_t;
C
chengduoZH 已提交
184
  HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
185
      : super_t(x), begin_(x), n_(n), post_(post) {}
C
chengduoZH 已提交
186 187 188 189 190 191
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
192
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
193 194 195 196 197
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

198 199
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
200 201
class TransformFunctor {
 public:
C
chengduoZH 已提交
202
  TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
Q
QI JUN 已提交
203
                   framework::Tensor* z, const DeviceContext& ctx, Functor func)
C
chengduoZH 已提交
204 205
      : x_(x->data<T>()),
        y_(y->data<T>()),
206
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
207 208 209 210 211
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
Q
QI JUN 已提交
212
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
213
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
214 215 216
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
217 218 219
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
C
chengduoZH 已提交
220 221 222
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
223 224 225
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
C
chengduoZH 已提交
226 227
  }

C
chengduoZH 已提交
228
 private:
C
chengduoZH 已提交
229 230
  const T* x_;
  const T* y_;
231
  OutType* z_;
C
chengduoZH 已提交
232
  int64_t nx_;
Q
QI JUN 已提交
233
  const DeviceContext& ctx_;
C
chengduoZH 已提交
234 235 236
  Functor func_;
};

237 238
#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
Q
QI JUN 已提交
239
    template <typename DeviceContext, typename T>                              \
240 241 242 243 244 245
    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
                    framework::Tensor* z,                                      \
                    const framework::ExecutionContext& ctx) {                  \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
Q
QI JUN 已提交
246 247 248
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_e);                                                  \
249
    }                                                                          \
Q
QI JUN 已提交
250
    template <typename DeviceContext, typename T>                              \
251 252 253 254 255 256 257 258 259 260
    inline void RunBroadCast(const framework::Tensor* x,                       \
                             const framework::Tensor* y, framework::Tensor* z, \
                             const framework::ExecutionContext& ctx, int pre,  \
                             int n) {                                          \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
Q
QI JUN 已提交
261 262 263
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
264
    }                                                                          \
Q
QI JUN 已提交
265
    template <typename DeviceContext, typename T>                              \
266 267 268 269 270 271 272 273 274 275 276
    inline void RunBroadCast2(const framework::Tensor* x,                      \
                              const framework::Tensor* y,                      \
                              framework::Tensor* z,                            \
                              const framework::ExecutionContext& ctx, int pre, \
                              int n, int post) {                               \
      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
Q
QI JUN 已提交
277 278 279
      z_e.device(                                                              \
          *ctx.template device_context<DeviceContext>().eigen_device()) =      \
          eigen_op(x_e, y_bcast);                                              \
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
    }                                                                          \
  }

#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);

#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);

Y
Yu Yang 已提交
295 296 297 298 299 300 301 302 303 304 305 306
// Per-element gradient functor for the case where x, y, out and dout share
// one shape. ElemwiseGradCompute runs it once per element index through
// platform::ForRange; a null dx_ or dy_ skips that gradient.
// NOTE: ElemwiseGradCompute brace-initializes this aggregate positionally
// ({x, y, out, dout, dx_op, dy_op, dx, dy}), so keep the data members in
// exactly this order.
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  // Read-only inputs, indexed by element number.
  const T* x_;
  const T* y_;
  const T* out_;
  const T* dout_;
  // Gradient rules: (x, y, out, dout) -> gradient for one element.
  DX_OP dx_op_;
  DY_OP dy_op_;
  // Outputs; either may be nullptr.
  T* dx_;
  T* dy_;

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    if (dy_ != nullptr) dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
  }
};

// CPU gradient for the row-wise broadcast (post == 1): x, out and dout are
// viewed as (h, w) matrices and y as a length-w vector shared by every row.
//   dx[i * w + j] = dx_op(x, y[j], out, dout)        for every element
//   dy[j]         = sum over i of dy_op(x, y[j], out, dout)
// Pass nullptr for dx or dy to skip that gradient. dy is zeroed before
// accumulating, so with h == 0 it is well-defined (all zeros) instead of
// being left untouched.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T* dx, T* dy) {
  if (dy != nullptr) {
    for (int j = 0; j < w; ++j) {
      dy[j] = static_cast<T>(0);
    }
  }
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      // 64-bit offset: i * w can exceed INT_MAX for large inputs.
      const int64_t x_offset = static_cast<int64_t>(i) * w + j;
      if (dx != nullptr) {
        dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
      }
      if (dy != nullptr) {
        // Convert to T before accumulating, as the per-element result is a T.
        const T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        dy[j] += tmp;
      }
    }
  }
}
338

D
dzhwinter 已提交
339
#ifdef __NVCC__
Y
Yu Yang 已提交
340 341 342 343 344 345 346
// One block per column `col` of the (h, w) view of x. Each thread starts at
// row threadIdx.x and strides by ELEMWISE_MAX_BLOCK_DIM rows, writing dx
// directly and keeping a private partial sum of dy_op for this column. The
// partials are then combined with platform::reduceSum (presumably a block-wide
// sum over the first `active` threads; defined in cuda_device_function.h) and
// thread 0 stores dy[col].
// NOTE(review): the do/while body runs once before the bound is checked, so
// the launcher must not start more threads than there are rows.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int col = blockIdx.x;
  const int tid = threadIdx.x;
  int row = tid;
  T partial = 0;

  do {
    const int offset = row * w + col;
    if (dx) {
      dx[offset] = dx_op(x[offset], y[col], out[offset], dout[offset]);
    }
    if (dy) {
      partial += dy_op(x[offset], y[col], out[offset], dout[offset]);
    }
    row += ELEMWISE_MAX_BLOCK_DIM;
  } while (row < h);

  if (dy) {
    // Number of threads holding a partial sum: min(h, ELEMWISE_MAX_BLOCK_DIM).
    const int active = h < ELEMWISE_MAX_BLOCK_DIM ? h : ELEMWISE_MAX_BLOCK_DIM;
    partial = paddle::platform::reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[col] = partial;
    }
  }
}

// Launches ElemwiseGradBroadcast1CUDAKernel on `stream`: one block per column
// (w blocks) with min(h, ELEMWISE_MAX_BLOCK_DIM) threads each.
// Empty shapes return early, because a launch with zero blocks or zero threads
// is an invalid configuration. When h == 0, dy[j] is an empty sum, so dy is
// zeroed instead of being left uninitialized.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T* dx, T* dy) {
  if (w <= 0) {
    return;  // no columns: nothing to compute
  }
  if (h <= 0) {
    if (dy != nullptr) {
      // Assumes T is an IEEE floating-point or integer type, whose all-zero
      // byte pattern is the value 0.
      cudaMemsetAsync(dy, 0, sizeof(T) * static_cast<size_t>(w), stream);
    }
    return;
  }
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  int grid_size = w;
  ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
      x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
}

#endif

// CPU gradient for the mid-axis broadcast: x, out and dout are viewed as
// (pre, n, post) and y as a length-n vector broadcast along pre and post.
//   dx[(i * n + j) * post + k] = dx_op(x, y[j], out, dout)  for every element
//   dy[j] = sum over (i, k) of dy_op(x, y[j], out, dout)
// Pass nullptr for dx or dy to skip that gradient. dy is zeroed before
// accumulating, so with pre == 0 or post == 0 it is well-defined (all zeros)
// instead of being left untouched.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  if (dy != nullptr) {
    for (int j = 0; j < n; ++j) {
      dy[j] = static_cast<T>(0);
    }
  }
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        // 64-bit offset: i * n * post can exceed INT_MAX for large inputs.
        const int64_t x_offset =
            (static_cast<int64_t>(i) * n + j) * post + k;
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          // Convert to T before accumulating, as the per-element result is a T.
          const T tmp =
              dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
          dy[j] += tmp;
        }
      }
    }
  }
}

#ifdef __NVCC__
// One block per middle index j of the (pre, n, post) view of x. Threads walk
// the pre * post positions for this j in strides of ELEMWISE_MAX_BLOCK_DIM
// (flat / post is the `pre` index, flat % post the `post` index), write dx
// directly and keep a private partial sum of dy_op. The partials are combined
// with platform::reduceSum (presumably a block-wide sum over the first
// `active` threads; defined in cuda_device_function.h) and thread 0 stores
// dy[j].
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const int tid = threadIdx.x;
  const int j = blockIdx.x;
  T partial = 0;

  for (int flat = tid; flat / post < pre; flat += ELEMWISE_MAX_BLOCK_DIM) {
    const int i = flat / post;
    const int k = flat % post;
    const int offset = i * n * post + j * post + k;

    if (dx != nullptr) {
      dx[offset] = dx_op(x[offset], y[j], out[offset], dout[offset]);
    }
    if (dy != nullptr) {
      partial += dy_op(x[offset], y[j], out[offset], dout[offset]);
    }
  }

  if (dy) {
    // Number of threads holding a partial sum: min(pre * post, block limit).
    const int total = pre * post;
    const int active =
        total < ELEMWISE_MAX_BLOCK_DIM ? total : ELEMWISE_MAX_BLOCK_DIM;
    partial = paddle::platform::reduceSum(partial, tid, active);
    if (tid == 0) {
      dy[j] = partial;
    }
  }
}

// Launches ElemwiseGradBroadcast2CUDAKernel on `stream`: one block per middle
// index j (n blocks) with min(pre * post, ELEMWISE_MAX_BLOCK_DIM) threads each.
// Empty shapes return early, because a launch with zero blocks or zero threads
// is an invalid configuration. When pre * post == 0, dy[j] is an empty sum, so
// dy is zeroed instead of being left uninitialized.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T* dx, T* dy) {
  if (n <= 0) {
    return;  // no middle indices: nothing to compute
  }
  const int reduce_size = pre * post;
  if (reduce_size <= 0) {
    if (dy != nullptr) {
      // Assumes T is an IEEE floating-point or integer type, whose all-zero
      // byte pattern is the value 0.
      cudaMemsetAsync(dy, 0, sizeof(T) * static_cast<size_t>(n), stream);
    }
    return;
  }
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, reduce_size);
  int grid_size = n;
  ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}

#endif

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
                         const framework::Tensor& out,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  // Computes dx and/or dy of an element-wise binary op from x, y, out and
  // dout, using dx_op / dy_op per element. Either output may be nullptr.
  // Requested gradients are allocated once here.
  T* dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());

  if (x.dims() == y.dims()) {
    // Same shape: one independent gradient evaluation per element.
    size_t N = static_cast<size_t>(framework::product(x.dims()));
    platform::ForRange<DeviceContext> for_range(
        ctx.template device_context<DeviceContext>(), N);
    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
        dx_data, dy_data});
    return;
  }

  // Broadcast case: after trimming trailing 1s, y's shape must match x's
  // dimensions starting at `axis`, so x is viewed as (pre, n, post) and y as
  // (n). Validation mirrors ElementwiseComputeEx.
  auto x_dim = x.dims();
  auto y_dim = y.dims();
  PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                    "Rank of first input must >= rank of second input.");

  axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
  trim_trailing_singular_dims(&y_dim);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);

  // NOTE(review): when compiled without __NVCC__ the GPU branches below are
  // empty, so a GPU place would leave dx/dy uncomputed.
  const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());
  if (post == 1) {
    // x is (pre, n) and y is (n): row-wise broadcast.
    if (on_gpu) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, dx_op, dy_op,
          dx_data, dy_data);
#endif
    } else {
      ElemwiseGradBroadcast1CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                                dout.data<T>(), pre, n, dx_op, dy_op, dx_data,
                                dy_data);
    }
  } else {
    // x is (pre, n, post) and y is (n): mid-axis broadcast.
    if (on_gpu) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
          dy_op, dx_data, dy_data);
#endif
    } else {
      ElemwiseGradBroadcast2CPU(x.data<T>(), y.data<T>(), out.data<T>(),
                                dout.data<T>(), pre, n, post, dx_op, dy_op,
                                dx_data, dy_data);
    }
  }
}
Y
Yu Yang 已提交
522

Q
QI JUN 已提交
523
template <typename DeviceContext, typename T, typename functor,
F
fengjiayi 已提交
524
          typename broadcastfunctor, typename broadcast2functor>
C
chengduoZH 已提交
525 526 527 528 529 530
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
                            const framework::Tensor* x,
                            const framework::Tensor* y,
                            const framework::Tensor* out,
                            const framework::Tensor* dout, int axis,
                            framework::Tensor* dx, framework::Tensor* dy) {
Q
QI JUN 已提交
531
  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
550
  trim_trailing_singular_dims(&y_dims);
551
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
552 553

  int pre, n, post;
554
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
555 556 557 558 559 560 561 562 563 564 565

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
F
fengjiayi 已提交
566

567 568
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
C
chengduoZH 已提交
569 570
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                          const framework::Tensor* x,
C
chengduoZH 已提交
571
                          const framework::Tensor* y, int axis, Functor func,
C
chengduoZH 已提交
572
                          framework::Tensor* z) {
573
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
C
chengduoZH 已提交
574
      x, y, z, ctx.template device_context<DeviceContext>(), func);
F
fengjiayi 已提交
575 576 577 578 579 580 581 582 583 584 585 586 587 588

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
589
  trim_trailing_singular_dims(&y_dims);
590
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
F
fengjiayi 已提交
591 592

  int pre, n, post;
593
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
F
fengjiayi 已提交
594 595 596 597 598 599 600 601 602
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}

603 604
}  // namespace operators
}  // namespace paddle