elementwise_op_function.h 21.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#pragma once
Y
Yi Wang 已提交
16 17 18 19
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
20

C
chengduoZH 已提交
21 22
#ifdef __NVCC__
#include <thrust/iterator/iterator_adaptor.h>
Y
Yu Yang 已提交
23
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
C
chengduoZH 已提交
24 25
#endif

Y
Yi Wang 已提交
26
#include "paddle/fluid/operators/math/math_function.h"
Y
Yu Yang 已提交
27
#include "paddle/fluid/platform/for_range.h"
28 29 30 31 32 33 34 35 36 37

namespace paddle {
namespace operators {

/*
 * Computes the (pre, n, post) factorization used to broadcast Y over X for
 * Out = X ⊙ Y.
 *
 * When Y's shape differs from X's shape, Y is matched against the run of X's
 * dimensions that starts at `axis`, and X is then viewed as a 3-D tensor of
 * shape (pre, n, post):
 *   pre  = product of X's dimensions before `axis`
 *   n    = product of Y's dimensions
 *   post = product of X's dimensions after the run matched by Y
 *
 * Examples:
 * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), axis = 1
 *    pre = 2, n = 3 * 4, post = 5
 *    x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
 * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), axis = 2
 *    pre = 2 * 3, n = 4 * 5, post = 1
 *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 *
 * Each dimension of Y must equal the dimension of X it is matched against;
 * otherwise PADDLE_ENFORCE_EQ fails with "Broadcast dimension mismatch.".
 * The three results are returned through the reference parameters.
 */
inline void get_mid_dims(const framework::DDim& x_dims,
                         const framework::DDim& y_dims, const int axis,
                         int& pre, int& n, int& post) {
  pre = 1;
  n = 1;
  post = 1;

  const auto x_rank = x_dims.size();
  const auto y_rank = y_dims.size();

  // Leading dimensions of X that come before the broadcast axis.
  for (int d = 0; d < axis; ++d) {
    pre *= x_dims[d];
  }

  // Dimensions shared by X and Y; they have to agree one by one.
  for (int d = 0; d < y_rank; ++d) {
    PADDLE_ENFORCE_EQ(x_dims[d + axis], y_dims[d],
                      "Broadcast dimension mismatch.");
    n *= y_dims[d];
  }

  // Trailing dimensions of X that follow the part matched by Y.
  for (int d = axis + y_rank; d < x_rank; ++d) {
    post *= x_dims[d];
  }
}

Q
QI JUN 已提交
64
// Iterator over the broadcast operand for the row-wise case. Only declared
// here; the per-device specializations follow in this file.
template <class T, class DeviceContext>
class RowwiseTransformIterator;

// Iterator over the broadcast operand for the mid-wise case. Only declared
// here; the per-device specializations follow in this file.
template <class T, class DeviceContext>
class MidWiseTransformIterator;
C
chengduoZH 已提交
68 69

template <typename T>
Q
QI JUN 已提交
70
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
71
 public:
C
chengduoZH 已提交
72 73
  RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

Q
QI JUN 已提交
74
  RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
75
    ++i_;
C
chengduoZH 已提交
76 77 78
    if (UNLIKELY(i_ == n_)) {
      i_ = 0;
    }
C
chengduoZH 已提交
79 80 81
    return *this;
  }

Q
QI JUN 已提交
82 83
  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
84
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
85 86
  }

Q
QI JUN 已提交
87 88
  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
89
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
90 91 92 93
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
94
 private:
C
chengduoZH 已提交
95 96
  const T* ptr_;
  int i_;
C
chengduoZH 已提交
97
  int64_t n_;
C
chengduoZH 已提交
98 99 100
};

template <typename T>
Q
QI JUN 已提交
101
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
C
chengduoZH 已提交
102
 public:
C
chengduoZH 已提交
103 104 105
  MidWiseTransformIterator(const T* ptr, int n, int post)
      : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}

Q
QI JUN 已提交
106
  MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
C
chengduoZH 已提交
107
    ++j_;
C
chengduoZH 已提交
108 109
    if (UNLIKELY(j_ == post_)) {
      ++i_;
C
refine  
chengduoZH 已提交
110
      j_ = 0;
C
chengduoZH 已提交
111 112 113
      if (UNLIKELY(i_ == n_)) {
        i_ = 0;
      }
C
chengduoZH 已提交
114
    }
C
chengduoZH 已提交
115 116 117
    return *this;
  }

Q
QI JUN 已提交
118 119
  bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
120
    return (ptr_ + i_) == &(*rhs);
C
chengduoZH 已提交
121 122
  }

Q
QI JUN 已提交
123 124
  bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
                      rhs) const {
C
chengduoZH 已提交
125
    return (ptr_ + i_) != &(*rhs);
C
chengduoZH 已提交
126 127 128 129
  }

  const T& operator*() { return ptr_[i_]; }

C
chengduoZH 已提交
130
 private:
C
chengduoZH 已提交
131
  const T* ptr_;
C
refine  
chengduoZH 已提交
132
  int64_t i_;
C
chengduoZH 已提交
133 134
  int64_t j_;
  int64_t n_;
C
refine  
chengduoZH 已提交
135
  int64_t post_;
C
chengduoZH 已提交
136 137
};

C
chengduoZH 已提交
138 139
#ifdef __NVCC__
template <typename T>
Q
QI JUN 已提交
140
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
141
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
142
          RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
143 144
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
145
      RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
146
      super_t;
C
chengduoZH 已提交
147
  HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
C
chengduoZH 已提交
148 149 150 151 152 153
      : super_t(x), begin_(x), n_(n){};
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
154
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
155 156 157 158 159
    return *(begin_ + (this->base() - begin_) % n_);
  }
};

template <typename T>
Q
QI JUN 已提交
160
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
C
chengduoZH 已提交
161
    : public thrust::iterator_adaptor<
Q
QI JUN 已提交
162
          MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
C
chengduoZH 已提交
163 164
 public:
  typedef thrust::iterator_adaptor<
Q
QI JUN 已提交
165
      MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
C
chengduoZH 已提交
166
      super_t;
C
chengduoZH 已提交
167
  HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
C
chengduoZH 已提交
168 169 170 171 172 173 174
      : super_t(x), begin_(x), n_(n), post_(post){};
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;
  unsigned int n_;
  const T* begin_;
C
chengduoZH 已提交
175
  HOSTDEVICE typename super_t::reference dereference() const {
C
chengduoZH 已提交
176 177 178 179 180
    return *(begin_ + (((this->base() - begin_) / post_) % n_));
  }
};
#endif

181 182
template <typename Functor, typename T, typename DeviceContext,
          typename OutType = T>
C
chengduoZH 已提交
183 184
class TransformFunctor {
 public:
C
chengduoZH 已提交
185
  TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
Q
QI JUN 已提交
186
                   framework::Tensor* z, const DeviceContext& ctx, Functor func)
C
chengduoZH 已提交
187 188
      : x_(x->data<T>()),
        y_(y->data<T>()),
189
        z_(z->mutable_data<OutType>(ctx.GetPlace())),
C
chengduoZH 已提交
190 191 192 193 194
        nx_(x->numel()),
        ctx_(ctx),
        func_(func) {}

  inline void Run() const {
Q
QI JUN 已提交
195
    platform::Transform<DeviceContext> trans;
C
chengduoZH 已提交
196
    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
C
chengduoZH 已提交
197 198 199
  }

  inline void RunRowWise(int n, int pre) const {
Q
QI JUN 已提交
200 201 202
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
          z_, func_);
C
chengduoZH 已提交
203 204 205
  }

  inline void RunMidWise(int n, int pre, int post) const {
Q
QI JUN 已提交
206 207 208
    platform::Transform<DeviceContext> trans;
    trans(ctx_, x_, x_ + nx_,
          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
C
chengduoZH 已提交
209 210
  }

C
chengduoZH 已提交
211
 private:
C
chengduoZH 已提交
212 213
  const T* x_;
  const T* y_;
214
  OutType* z_;
C
chengduoZH 已提交
215
  int64_t nx_;
Q
QI JUN 已提交
216
  const DeviceContext& ctx_;
C
chengduoZH 已提交
217 218 219
  Functor func_;
};

220 221
// EIGEN_FUNCTOR(name, eigen_op) defines `struct Eigen<name>Functor` with three
// member templates that evaluate eigen_op(x, y) into z on the Eigen device of
// the DeviceContext taken from the execution context. All tensors are
// flattened to 1-D Eigen vectors first.
//   Run           - x and y are combined element by element.
//   RunBroadCast  - y is reshaped to (1, n), tiled `pre` times along the
//                   first dimension and flattened to x's size.
//   RunBroadCast2 - y is reshaped to (1, n, 1), tiled to (pre, n, post) and
//                   flattened to x's size.
// The macro body deliberately contains no comments and no trailing semicolon;
// the semicolon is supplied at the point of use.
#define EIGEN_FUNCTOR(name, eigen_op)                                          \
  struct Eigen##name##Functor {                                                \
    template <typename DeviceContext, typename T>                              \
    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
                    framework::Tensor* z,                                      \
                    const framework::ExecutionContext& ctx) {                  \
      auto lhs = framework::EigenVector<T>::Flatten(*x);                       \
      auto rhs = framework::EigenVector<T>::Flatten(*y);                       \
      auto result = framework::EigenVector<T>::Flatten(*z);                    \
      auto& eigen_dev =                                                        \
          *ctx.template device_context<DeviceContext>().eigen_device();        \
      result.device(eigen_dev) = eigen_op(lhs, rhs);                           \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast(const framework::Tensor* x,                       \
                             const framework::Tensor* y, framework::Tensor* z, \
                             const framework::ExecutionContext& ctx, int pre,  \
                             int n) {                                          \
      auto lhs = framework::EigenVector<T>::Flatten(*x);                       \
      auto rhs = framework::EigenVector<T>::Flatten(*y);                       \
      auto result = framework::EigenVector<T>::Flatten(*z);                    \
      Eigen::DSizes<int, 2> row_shape(1, n);                                   \
      Eigen::DSizes<int, 2> row_tiles(pre, 1);                                 \
      Eigen::DSizes<int, 1> flat_shape(lhs.size());                            \
      auto rhs_tiled =                                                         \
          rhs.reshape(row_shape).broadcast(row_tiles).reshape(flat_shape);     \
      auto& eigen_dev =                                                        \
          *ctx.template device_context<DeviceContext>().eigen_device();        \
      result.device(eigen_dev) = eigen_op(lhs, rhs_tiled);                     \
    }                                                                          \
    template <typename DeviceContext, typename T>                              \
    inline void RunBroadCast2(const framework::Tensor* x,                      \
                              const framework::Tensor* y,                      \
                              framework::Tensor* z,                            \
                              const framework::ExecutionContext& ctx, int pre, \
                              int n, int post) {                               \
      auto lhs = framework::EigenVector<T>::Flatten(*x);                       \
      auto rhs = framework::EigenVector<T>::Flatten(*y);                       \
      auto result = framework::EigenVector<T>::Flatten(*z);                    \
      Eigen::DSizes<int, 3> mid_shape(1, n, 1);                                \
      Eigen::DSizes<int, 3> mid_tiles(pre, 1, post);                           \
      Eigen::DSizes<int, 1> flat_shape(lhs.size());                            \
      auto rhs_tiled =                                                         \
          rhs.reshape(mid_shape).broadcast(mid_tiles).reshape(flat_shape);     \
      auto& eigen_dev =                                                        \
          *ctx.template device_context<DeviceContext>().eigen_device();        \
      result.device(eigen_dev) = eigen_op(lhs, rhs_tiled);                     \
    }                                                                          \
  }

Q
QI JUN 已提交
266
template <class functor, typename DeviceContext, typename T>
267 268 269 270 271 272 273 274 275 276 277
void ElementwiseCompute(const framework::ExecutionContext& ctx) {
  using Tensor = framework::Tensor;

  auto* x = ctx.Input<Tensor>("X");
  auto* y = ctx.Input<Tensor>("Y");
  auto* z = ctx.Output<Tensor>("Out");
  z->mutable_data<T>(ctx.GetPlace());

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
278
                    "Rank of first input must >= rank of second input.");
279

Q
qijun 已提交
280
  if (x_dims == y_dims) {
281
    functor f;
Q
QI JUN 已提交
282
    f.template Run<DeviceContext, T>(x, y, z, ctx);
283 284 285 286 287 288 289 290 291 292 293 294
    return;
  }

  int axis = ctx.Attr<int>("axis");
  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
  if (post == 1) {
    functor f;
Q
QI JUN 已提交
295
    f.template RunBroadCast<DeviceContext, T>(x, y, z, ctx, pre, n);
296 297 298
    return;
  } else {
    functor f;
Q
QI JUN 已提交
299
    f.template RunBroadCast2<DeviceContext, T>(x, y, z, ctx, pre, n, post);
300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
    return;
  }
}

// Each pair below first defines the binary expression as a function-like
// macro and then instantiates EIGEN_FUNCTOR with it, which defines the struct
// Eigen<Name>Functor for that operation.

// Addition: defines EigenAddFunctor.
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);

// Subtraction: defines EigenSubFunctor.
#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);

// Multiplication: defines EigenMulFunctor.
#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);

// Division: defines EigenDivFunctor.
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);

Y
Yu Yang 已提交
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567
// Per-element gradient functor for the case where X and Y have the same
// shape. Invoked with an element index i, it writes
//   dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i])   if dx_ is not null
//   dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i])   if dy_ is not null
//
// The struct is aggregate-initialized by its user, so the member order
// (x_, y_, out_, dout_, dx_op_, dy_op_, dx_, dy_) is part of its interface
// and must not be changed.
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
  const T* x_;
  const T* y_;
  const T* out_;
  const T* dout_;

  HOSTDEVICE void operator()(size_t i) {
    if (dx_ != nullptr) {
      dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
    if (dy_ != nullptr) {
      // The gradient w.r.t. Y must come from dy_op_, not dx_op_.
      dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
    }
  }

  DX_OP dx_op_;
  DY_OP dy_op_;
  T* dx_;  // may be null: gradient w.r.t. X not requested
  T* dy_;  // may be null: gradient w.r.t. Y not requested
};

// CPU gradient for the row-wise broadcast case: x, out, dout (and dx) are
// h-by-w arrays in row-major order, y (and dy) has w elements.
//
// For every element (i, j):
//   dx[i * w + j] = dx_op(x, y[j], out, dout)            if dx is not null
//   dy[j]         = sum over i of dy_op(x, y[j], out, dout)  if dy is not null
//
// dy[j] is overwritten on the first row and accumulated on the following
// rows, so the caller does not need to zero it (it is left untouched when
// h == 0).
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int h, int w, DX_OP dx_op,
                                      DY_OP dy_op, T* dx, T* dy) {
  const bool want_dx = (dx != nullptr);
  const bool want_dy = (dy != nullptr);

  for (int row = 0; row < h; ++row) {
    const int row_base = row * w;
    for (int col = 0; col < w; ++col) {
      const int idx = row_base + col;

      if (want_dx) {
        dx[idx] = dx_op(x[idx], y[col], out[idx], dout[idx]);
      }

      if (!want_dy) {
        continue;
      }

      T contribution = dy_op(x[idx], y[col], out[idx], dout[idx]);
      if (row != 0) {
        dy[col] += contribution;
      } else {
        // First row initializes the accumulator.
        dy[col] = contribution;
      }
    }
  }
}
#ifdef __NVCC__
// CUDA kernel for the row-wise broadcast gradient over an h-by-w row-major
// array: block j handles column j, and each thread of the block handles rows
// threadIdx.x, threadIdx.x + ELEMWISE_MAX_BLOCK_DIM, ...
//
//   dx[i * w + j] = dx_op(...)                 if dx is not null
//   dy[j]         = sum over i of dy_op(...)   if dy is not null
//
// Launch requirements, as set up by ElemwiseGradBroadcast1CUDA in this file:
// gridDim.x == w, blockDim.x == min(h, ELEMWISE_MAX_BLOCK_DIM), and
// blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int h, int w,
    DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  // Dynamic shared memory, reinterpreted as one partial sum of type T per
  // thread.
  extern __shared__ char shm_buffer[];
  T* shm = reinterpret_cast<T*>(shm_buffer);

  int j = blockIdx.x;   // column handled by this block
  int i = threadIdx.x;  // first row handled by this thread
  int tid = threadIdx.x;
  shm[tid] = 0;

  // The body runs once before `i < h` is tested, so this relies on
  // threadIdx.x < h (see the launch requirements above). The row stride is
  // the compile-time constant ELEMWISE_MAX_BLOCK_DIM, not blockDim.x.
  do {
    int x_offset = i * w + j;
    if (dx) {
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
    if (dy) {
      // Accumulate this thread's share of dy[j].
      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }
    i += ELEMWISE_MAX_BLOCK_DIM;
  } while (i < h);

  if (dy) {
    // Make every thread's partial sum visible before thread 0 reads them.
    // `dy` is a kernel argument, so all threads take this branch together.
    __syncthreads();

    // Number of partial sums that were written: min(h, ELEMWISE_MAX_BLOCK_DIM).
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;

    // Serial sum of the partial results by thread 0; could be optimized.
    if (threadIdx.x == 0) {
      for (int k = 1; k < h; ++k) {
        shm[0] += shm[k];
      }
      dy[j] = shm[0];
    }
  }
}

// Host-side launcher for ElemwiseGradBroadcast1CUDAKernel on `stream`:
// one block per column (w blocks), min(h, ELEMWISE_MAX_BLOCK_DIM) threads per
// block, and one T of dynamic shared memory per thread.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
                                       T* dx, T* dy) {
  const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
  const int num_blocks = w;
  const int shm_bytes = threads_per_block * sizeof(T);

  ElemwiseGradBroadcast1CUDAKernel<<<num_blocks, threads_per_block, shm_bytes,
                                     stream>>>(x, y, out, dout, h, w, dx_op,
                                               dy_op, dx, dy);
}

#endif

// CPU gradient for the mid-wise broadcast case: x, out, dout (and dx) are
// (pre, n, post) arrays in row-major order, y (and dy) has n elements.
//
// For every element (i, j, k), with offset = i * n * post + j * post + k:
//   dx[offset] = dx_op(x, y[j], out, dout)                   if dx is not null
//   dy[j]      = sum over (i, k) of dy_op(x, y[j], out, dout)  if dy is not
//                null
//
// dy[j] is overwritten by its first contribution (i == 0, k == 0) and
// accumulated afterwards, so the caller does not need to zero it.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
                                      const T* dout, int pre, int n, int post,
                                      DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  const bool want_dx = (dx != nullptr);
  const bool want_dy = (dy != nullptr);

  for (int outer = 0; outer < pre; ++outer) {
    for (int mid = 0; mid < n; ++mid) {
      for (int inner = 0; inner < post; ++inner) {
        const int idx = outer * n * post + mid * post + inner;

        if (want_dx) {
          dx[idx] = dx_op(x[idx], y[mid], out[idx], dout[idx]);
        }

        if (!want_dy) {
          continue;
        }

        T contribution = dy_op(x[idx], y[mid], out[idx], dout[idx]);
        const bool first_for_mid = (outer == 0 && inner == 0);
        if (first_for_mid) {
          dy[mid] = contribution;
        } else {
          dy[mid] += contribution;
        }
      }
    }
  }
}

#ifdef __NVCC__

// CUDA kernel for the mid-wise broadcast gradient over a (pre, n, post)
// row-major array: block j handles the j-th slice along the middle dimension,
// and the block's threads share the pre * post elements of that slice.
//
//   dx[i * n * post + j * post + k] = dx_op(...)          if dx is not null
//   dy[j] = sum over (i, k) of dy_op(...)                 if dy is not null
//
// Launch configuration, as set up by ElemwiseGradBroadcast2CUDA in this file:
// gridDim.x == n, blockDim.x == min(pre * post, ELEMWISE_MAX_BLOCK_DIM), and
// blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
    const T* x, const T* y, const T* out, const T* dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
  int tid = threadIdx.x;
  int j = blockIdx.x;  // index along the middle (broadcast) dimension

  // Dynamic shared memory, reinterpreted as one partial sum of type T per
  // thread.
  extern __shared__ char shm_buffer[];
  T* shm = reinterpret_cast<T*>(shm_buffer);
  shm[tid] = 0;
  // Linear index into the pre * post elements of slice j; starts at the
  // thread id and advances by ELEMWISE_MAX_BLOCK_DIM per iteration.
  int ttid = tid;

  while (true) {
    // Split the linear index into (i, k) coordinates of the slice.
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;  // past the end of the slice

    int x_offset = i * n * post + j * post + k;

    if (dx != nullptr) {
      dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }

    if (dy != nullptr) {
      // Accumulate this thread's share of dy[j].
      shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
    }

    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

  if (dy) {
    // Make every thread's partial sum visible before thread 0 reads them.
    // `dy` is a kernel argument, so all threads take this branch together.
    __syncthreads();
    // Number of partial sums to combine: min(pre * post,
    // ELEMWISE_MAX_BLOCK_DIM).
    int h = pre * post;
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;

    // Serial sum of the partial results by thread 0; could be optimized.
    if (tid == 0) {
      for (int i = 1; i < h; ++i) {
        shm[0] += shm[i];
      }
      dy[j] = shm[0];
    }
  }
}

// Host-side launcher for ElemwiseGradBroadcast2CUDAKernel on `stream`:
// one block per middle index (n blocks), min(pre * post,
// ELEMWISE_MAX_BLOCK_DIM) threads per block, and one T of dynamic shared
// memory per thread.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
                                       const T* y, const T* out, const T* dout,
                                       int pre, int n, int post, DX_OP dx_op,
                                       DY_OP dy_op, T* dx, T* dy) {
  const int threads_per_block = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
  const int num_blocks = n;
  const int shm_bytes = threads_per_block * sizeof(T);

  ElemwiseGradBroadcast2CUDAKernel<<<num_blocks, threads_per_block, shm_bytes,
                                     stream>>>(x, y, out, dout, pre, n, post,
                                               dx_op, dy_op, dx, dy);
}

#endif

// Backward pass of an element-wise binary op, driven by per-element functors.
//
// dx_op / dy_op are called as op(x, y, out, dout) and return the gradient
// contribution for X / Y of that element. dx and dy may each be null, in
// which case that gradient is skipped; otherwise their buffers are allocated
// on ctx's place.
//
// - Identical shapes: every element is handled independently through
//   platform::ForRange and ElemwiseGradNoBroadcast.
// - Otherwise Y is broadcast over X starting at `axis` (-1 means "align Y
//   with X's trailing dimensions") and dY is reduced over the broadcast
//   dimensions by the ElemwiseGradBroadcast{1,2}{CPU,CUDA} helpers.
//
// Enforces rank(X) >= rank(Y) and 0 <= axis < rank(X), matching
// ElementwiseComputeEx.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
                         const framework::Tensor& out,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  // Resolve the raw buffers once instead of repeating the lookups per branch.
  const T* x_data = x.data<T>();
  const T* y_data = y.data<T>();
  const T* out_data = out.data<T>();
  const T* dout_data = dout.data<T>();
  T* dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());

  if (x.dims() == y.dims()) {
    // No broadcasting: one independent computation per element.
    size_t N = static_cast<size_t>(framework::product(x.dims()));
    platform::ForRange<DeviceContext> for_range(
        ctx.template device_context<DeviceContext>(), N);
    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
        x_data, y_data, out_data, dout_data, dx_op, dy_op, dx_data, dy_data});
    return;
  }

  // Y has to be broadcast over X.
  auto x_dim = x.dims();
  auto y_dim = y.dims();
  PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                    "Rank of first input must >= rank of second input.");

  if (y_dim.size() == 1 && y_dim[0] == 1) {
    // y is a scalar: append a trailing 1 to X's shape so that Y's single
    // dimension matches it.
    auto extended_dims = framework::vectorize(x_dim);
    extended_dims.push_back(1);
    x_dim = framework::make_ddim(extended_dims);
  }

  axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dim.size(),
                 "Axis should be in range [0, x_dims)");

  int pre, n, post;
  get_mid_dims(x_dim, y_dim, axis, pre, n, post);

  // NOTE: when the place is a GPU place but this translation unit is not
  // compiled by NVCC, the GPU branches below compile to nothing.
  const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());

  if (post == 1) {
    // Row-wise broadcast: X is viewed as an h-by-w matrix, Y as w elements.
    int h = pre;
    int w = n;
    if (on_gpu) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x_data,
          y_data, out_data, dout_data, h, w, dx_op, dy_op, dx_data, dy_data);
#endif
    } else {
      ElemwiseGradBroadcast1CPU(x_data, y_data, out_data, dout_data, h, w,
                                dx_op, dy_op, dx_data, dy_data);
    }
  } else {
    // Mid-wise broadcast: X is viewed as (pre, n, post), Y as n elements.
    if (on_gpu) {
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x_data,
          y_data, out_data, dout_data, pre, n, post, dx_op, dy_op, dx_data,
          dy_data);
#endif
    } else {
      ElemwiseGradBroadcast2CPU(x_data, y_data, out_data, dout_data, pre, n,
                                post, dx_op, dy_op, dx_data, dy_data);
    }
  }
}

Q
QI JUN 已提交
568
template <typename DeviceContext, typename T, typename functor,
F
fengjiayi 已提交
569
          typename broadcastfunctor, typename broadcast2functor>
C
chengduoZH 已提交
570 571 572 573 574 575
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
                            const framework::Tensor* x,
                            const framework::Tensor* y,
                            const framework::Tensor* out,
                            const framework::Tensor* dout, int axis,
                            framework::Tensor* dx, framework::Tensor* dy) {
Q
QI JUN 已提交
576
  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

594 595 596 597 598 599 600
  if (y_dims.size() == 1 && y_dims[0] == 1) {
    // y is a scalar
    auto extended_dims = framework::vectorize(x_dims);
    extended_dims.push_back(1);
    x_dims = framework::make_ddim(extended_dims);
  }

601 602 603 604 605 606 607 608 609 610 611 612 613 614 615
  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, pre, n, post);

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
F
fengjiayi 已提交
616

617 618
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
C
chengduoZH 已提交
619 620
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                          const framework::Tensor* x,
C
chengduoZH 已提交
621
                          const framework::Tensor* y, int axis, Functor func,
C
chengduoZH 已提交
622
                          framework::Tensor* z) {
623
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
C
chengduoZH 已提交
624
      x, y, z, ctx.template device_context<DeviceContext>(), func);
F
fengjiayi 已提交
625 626 627 628 629 630 631 632 633 634 635

  auto x_dims = x->dims();
  auto y_dims = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims) {
    functor.Run();
    return;
  }

636 637 638 639 640 641 642
  if (y_dims.size() == 1 && y_dims[0] == 1) {
    // y is a scalar
    auto extended_dims = framework::vectorize(x_dims);
    extended_dims.push_back(1);
    x_dims = framework::make_ddim(extended_dims);
  }

F
fengjiayi 已提交
643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;
  } else {
    functor.RunMidWise(n, pre, post);
    return;
  }
}

658 659
}  // namespace operators
}  // namespace paddle