/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

#include <cstdint>

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

// Short aliases for the Paddle platform and operator namespaces, used both
// inside the paddle::operators namespace below and by the kernel
// registrations at the end of this file.
namespace plat = paddle::platform;
namespace ops = paddle::operators;

namespace paddle {
namespace operators {

25 26 27 28 29 30
/*
   input: an array;
   return: the result of the math functor
   1. For Unary Op, the length of input array is 1,
      e.g. Relu: return args[0] > 0 ? args[0] : 0;
   2. For Binary Op, the length of input array is 2,
31
      e.g. Add: return args[0] expr args[1];
32
*/
33
template <typename T>
34
struct CudaAddFunctor {
35
  inline HOSTDEVICE T operator()(const T* args) const {
36 37
    return args[0] + args[1];
  }
38 39
};

40
// CUDA forward kernel for elementwise_add (also registered for grad_add
// below). The operand/output tensors and the axis come from
// PackTensorsIntoVector and are handed to LaunchElementwiseCudaKernel together
// with CudaAddFunctor. Presumably the launcher handles any broadcasting along
// `axis` — NOTE(review): confirm against elementwise_op_broadcast.cu.h.
template <typename T>
class ElementwiseAddKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    std::vector<const framework::Tensor*> inputs;
    std::vector<framework::Tensor*> outputs;
    const auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    // Gather operand and output tensors; the returned axis is forwarded
    // unchanged to the launcher.
    const int axis = PackTensorsIntoVector<T>(ctx, &inputs, &outputs);

    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        dev_ctx, inputs, &outputs, axis, CudaAddFunctor<T>());
  }
};

// Writes dout[i] into both dx[i] and dy[i] for every i in [0, size).
//
// Launch: 1-D grid and 1-D blocks. Both phases are grid-stride loops, so any
// non-empty launch configuration copies every element.
// Preconditions:
//   - vec_size == sizeof(float4) / sizeof(T), i.e. the number of T values
//     carried by one float4 packet.
//   - dout, dx and dy each hold at least `size` elements.
// If any of the three pointers is not aligned for float4 access, the
// vectorized phase is skipped and every element is copied as a scalar T.
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
    const T* __restrict__ dout, int64_t size, int vec_size, T* dx, T* dy) {
  // 64-bit indexing: the element count comes from Tensor::numel() and may
  // exceed INT_MAX.
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

  // float4 loads/stores require 16-byte alignment. Nothing here guarantees it
  // for the incoming pointers, and a misaligned vector access would fault, so
  // check and fall back to scalar copies.
  const bool aligned = ((reinterpret_cast<std::uintptr_t>(dout) |
                         reinterpret_cast<std::uintptr_t>(dx) |
                         reinterpret_cast<std::uintptr_t>(dy)) %
                        alignof(float4)) == 0;
  const int64_t vec_count = aligned ? size / vec_size : 0;

  // Phase 1: whole float4 packets.
  const float4* dout_vec = reinterpret_cast<const float4*>(dout);
  float4* dx_vec = reinterpret_cast<float4*>(dx);
  float4* dy_vec = reinterpret_cast<float4*>(dy);
  for (int64_t i = tid; i < vec_count; i += stride) {
    const float4 packet = dout_vec[i];
    dx_vec[i] = packet;
    dy_vec[i] = packet;
  }

  // Phase 2: elements not covered by a whole packet (all of them when the
  // pointers are misaligned). This is grid-stride so it does not depend on a
  // particular thread index existing in the launch, and the work is spread
  // across threads instead of being done serially by one.
  for (int64_t i = vec_count * vec_size + tid; i < size; i += stride) {
    T value = dout[i];
    dx[i] = value;
    dy[i] = value;
  }
}

// CUDA overload of elementwise_add_grad. For Out = X + Y, both dX and dY are
// plain copies of dOut; this function produces them with as little copying
// as possible.
//
// Both dx and dy are dereferenced unconditionally, so both must be non-null.
// Behavior by storage sharing:
//   - dx shares storage with dout, dy does not: only dy is copied.
//   - dy shares storage with dout, dx does not: only dx is copied.
//   - neither shares storage: one vectorized kernel writes both.
//   - both share storage: nothing to do.
// NOTE(review): the element count is taken from x. This assumes x, dout, dx
// and dy all have the same number of elements; confirm that callers only
// dispatch here for same-shaped operands.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  // The gradient kernel moves sizeof(float4) / sizeof(T) elements per float4
  // packet. That is only exact when T tiles a float4 evenly; otherwise packets
  // would truncate or straddle elements.
  static_assert(sizeof(float4) % sizeof(T) == 0,
                "elementwise_add_grad: sizeof(T) must evenly divide "
                "sizeof(float4) for the vectorized gradient copy");

  auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
  auto* dout_data = dout->data<T>();
  if (dx_data == dout_data && dy_data != dout_data) {
    VLOG(4) << "Special case when dx_data is the same as dout_data, "
               "only need copy dout to dy";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), dy);
  } else if (dx_data != dout_data && dy_data == dout_data) {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "only need copy dout to dx";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), dx);
  } else if (dx_data != dout_data && dy_data != dout_data) {
    auto size = x->numel();
    // An empty tensor would produce a 0-block grid. CUDA rejects that as an
    // invalid launch configuration, and the error would surface at some later,
    // unrelated check. There is nothing to copy anyway.
    if (size == 0) {
      return;
    }
    constexpr int vec_size = static_cast<int>(sizeof(float4) / sizeof(T));
    // Roughly one thread per float4 packet (tail included).
    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
    dim3 grid_size =
        dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
                 PADDLE_CUDA_THREAD_SIZE,
             1);
    auto stream =
        ctx.template device_context<plat::CUDADeviceContext>().stream();
    // Reuse the pointers obtained above instead of calling mutable_data again.
    SimpleElemwiseAddGradCUDAKernel<T><<<grid_size, block_size, 0, stream>>>(
        dout_data, size, vec_size, dx_data, dy_data);
  } else {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "and dx_data is the same as dout_data, do not need "
               "any operator";
  }
}

}  // namespace operators
}  // namespace paddle
// ---------------------------------------------------------------------------
// CUDA kernel registrations. Every op below is registered for the same seven
// element types, in the same order: float, double, int, int64_t, float16,
// complex<float>, complex<double>.
// ---------------------------------------------------------------------------

// Forward: Out = X + Y.
REGISTER_OP_CUDA_KERNEL(
    elementwise_add,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);

// First-order gradient.
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
                                  plat::complex<double>>);

// Second-order gradient.
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad_grad,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);

// grad_add reuses the forward addition kernel.
REGISTER_OP_CUDA_KERNEL(
    grad_add,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);