/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

// Short namespace aliases used by the code and kernel registrations in this
// file.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Writes every element of `dout` into both `dx` and `dy`.
//
// Launch layout: 1-D grid of 1-D blocks.
//   * The first `size / vec_size` groups are copied as float4 vectors using a
//     grid-stride loop over group indices.
//   * The trailing `size % vec_size` scalars are copied serially by the single
//     thread whose global id equals the number of full vector groups. When
//     there is a remainder, the launch must therefore provide at least
//     `size / vec_size + 1` threads.
//
// Preconditions:
//   * vec_size == sizeof(float4) / sizeof(T), because data is always moved in
//     float4 units. The static_assert below guarantees this ratio is an
//     integer >= 1.
//   * dout, dx and dy are suitably aligned for float4 (16-byte) accesses.
//     NOTE(review): this relies on the allocator's alignment; confirm for
//     tensors that may carry a byte offset.
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
    const T* __restrict__ dout, int64_t size, int vec_size, T* dx, T* dy) {
  static_assert(sizeof(T) <= sizeof(float4) && sizeof(float4) % sizeof(T) == 0,
                "SimpleElemwiseAddGradCUDAKernel moves data in float4 units; "
                "sizeof(T) must evenly divide sizeof(float4).");
  // 64-bit indices: size comes from Tensor::numel() and may exceed INT_MAX.
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t loop = size / vec_size;
  const int64_t remainder = size % vec_size;
  const float4* dout_vec = reinterpret_cast<const float4*>(dout);
  float4* dx_vec = reinterpret_cast<float4*>(dx);
  float4* dy_vec = reinterpret_cast<float4*>(dy);

  for (int64_t i = tid; i < loop; i += stride) {
    const float4 tmp_loop = dout_vec[i];
    dx_vec[i] = tmp_loop;
    dy_vec[i] = tmp_loop;
  }

  // Tail elements that do not fill a whole float4, copied in increasing index
  // order by exactly one thread.
  if (tid == loop && remainder != 0) {
    for (int64_t idx = size - remainder; idx < size; ++idx) {
      const T tmp_rem = dout[idx];
      dx[idx] = tmp_rem;
      dy[idx] = tmp_rem;
    }
  }
}

// General CUDA gradient of `out = x + y` (with the op's "axis" attribute).
//
// For each of dx / dy that is non-null:
//   * If its dims equal dout's dims, dout is copied into it. The copy is
//     skipped when it already shares dout's data pointer.
//   * Otherwise it is dout sum-reduced (kps::AddFunctor) over the dims
//     returned by GetReduceDim(<input dims>, out->dims(), axis), on the
//     context's CUDA stream.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  auto* dout_data = dout->data<T>();

  // dx
  if (dx != nullptr) {
    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    if (dx->dims() == dout->dims()) {
      if (dx_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dx);
      }
    } else {
      // For inplace strategy, dx will be stored in addr of dout, which makes
      // the result of dy wrong. Give dx its own buffer before reducing into it.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
  // dy
  if (dy != nullptr) {
    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
    if (dy->dims() == dout->dims()) {
      if (dy_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dy);
      }
    } else {
      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          *dout, dy, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
}

// CUDA gradient of `out = x + y` for the case where dx and dy are written
// element-for-element from dout, with no reduction.
//
// Both dx and dy are dereferenced unconditionally, so they must be non-null.
// NOTE(review): presumably the caller in elementwise_add_op.h only routes here
// when x, dx, dy and dout have the same shape; confirm there. `size` below is
// taken from x->numel().
//
// Aliasing cases, decided by comparing data pointers after allocation:
//   * only dx aliases dout  -> copy dout into dy;
//   * only dy aliases dout  -> copy dout into dx;
//   * neither aliases dout  -> one kernel writes dout into both;
//   * both alias dout       -> nothing to do.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
  auto* dout_data = dout->data<T>();
  if (dx_data == dout_data && dy_data != dout_data) {
    VLOG(4) << "Special case when dx_data is the same as dout_data, "
               "only need copy dout to dy";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), dy);
  } else if (dx_data != dout_data && dy_data == dout_data) {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "only need copy dout to dx";
    framework::TensorCopy(
        *dout, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), dx);
  } else if (dx_data != dout_data && dy_data != dout_data) {
    // Kept 64-bit end to end: numel() may exceed INT_MAX.
    const int64_t size = x->numel();
    if (size == 0) {
      // A grid with 0 blocks is an invalid launch configuration, and there is
      // nothing to copy anyway.
      return;
    }
    // The kernel always moves float4 units. Its static_assert guarantees this
    // ratio is an integer >= 1 for every instantiated T.
    constexpr int vec_size = static_cast<int>(sizeof(float4) / sizeof(T));
    // One thread per (possibly partial) float4 group. This provides the extra
    // thread the kernel needs for the scalar tail when size % vec_size != 0.
    const int64_t num_vec_groups = (size + vec_size - 1) / vec_size;
    dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
    dim3 grid_size = dim3(
        static_cast<unsigned int>((num_vec_groups + ELEMENTWISE_BLOCK_SIZE - 1) /
                                  ELEMENTWISE_BLOCK_SIZE),
        1);
    SimpleElemwiseAddGradCUDAKernel<
        T><<<grid_size, block_size, 0,
             ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
        dout_data, size, vec_size, dx_data, dy_data);
  } else {
    VLOG(4) << "Special case when dy_data is the same as dout_data, "
               "and dx_data is the same as dout_data, do not need "
               "any operator";
  }
}

}  // namespace operators
}  // namespace paddle

// CUDA kernel registrations for elementwise_add, its first/second/third-order
// gradients, and grad_add (which reuses the forward ElementwiseAddKernel).
// Every op is registered for the same dtype set: float, double, int, int64_t,
// float16, complex<float> and complex<double>.
REGISTER_OP_CUDA_KERNEL(
    elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext,
                                  plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad_grad,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
    elementwise_add_triple_grad,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
    grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);