/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"

#include <cstdint>

#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Returns true when `ptr` satisfies float4's alignment (16 bytes). This is a
// precondition for reinterpreting the pointer as float4* for vector
// loads/stores; a misaligned vector access faults on the device.
// NOTE(review): tensor data pointers are not assumed to be allocation-aligned
// here (presumably a tensor can start at an offset inside its allocation,
// e.g. after slicing -- confirm against framework::Tensor), so the vectorized
// paths below check this at runtime and fall back to scalar code.
template <typename T>
static inline bool IsFloat4Aligned(const T* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % alignof(float4) == 0;
}

// Same-shape elementwise add (z = x + y) on CUDA for every element type other
// than float and float16, which have the vectorized specialization below.
// Applies AddRangeFunctor to each of the x->numel() elements through ForRange
// and writes into z's existing buffer (z->data<T>(), not mutable_data).
template <typename T>
struct SameDimsElemwiseAdd<
    platform::CUDADeviceContext, T,
    typename std::enable_if<!std::is_same<T, platform::float16>::value &&
                            !std::is_same<T, float>::value>::type> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor* x, const framework::Tensor* y,
                  framework::Tensor* z) {
    const int64_t numel = x->numel();
    // Empty tensors: nothing to compute, so launch nothing.
    if (numel == 0) return;
    AddRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, numel);
    for_range(functor);
  }
};

// Launches SameDimsElemwiseAddCUDAKernel on float data on dev_ctx's stream.
static inline void LaunchSameDimsAddCUDAKernel(
    const platform::CUDADeviceContext& dev_ctx, dim3 grid_size,
    dim3 block_size, const float* x, const float* y, float* z, int64_t size) {
  SameDimsElemwiseAddCUDAKernel<<<grid_size, block_size, 0,
                                  dev_ctx.stream()>>>(x, y, z, size);
}

// float16 overload: the kernel is instantiated on CUDA's native `half`, so
// platform::float16 buffers are reinterpreted as half buffers.
static inline void LaunchSameDimsAddCUDAKernel(
    const platform::CUDADeviceContext& dev_ctx, dim3 grid_size,
    dim3 block_size, const platform::float16* x, const platform::float16* y,
    platform::float16* z, int64_t size) {
  SameDimsElemwiseAddCUDAKernel<<<grid_size, block_size, 0,
                                  dev_ctx.stream()>>>(
      reinterpret_cast<const half*>(x), reinterpret_cast<const half*>(y),
      reinterpret_cast<half*>(z), size);
}

// Same-shape elementwise add (z = x + y) on CUDA for float and float16.
// The grid is sized so that each thread covers sizeof(float4) bytes
// (vec_size = sizeof(float4) / sizeof(T) elements). This assumes
// SameDimsElemwiseAddCUDAKernel uses float4-wide accesses -- TODO confirm in
// elementwise_op_function.cu.h. When any buffer is not float4-aligned, the
// element-wise AddRangeFunctor path is used instead; it computes the same sum.
template <typename T>
struct SameDimsElemwiseAdd<
    platform::CUDADeviceContext, T,
    typename std::enable_if<std::is_same<T, platform::float16>::value ||
                            std::is_same<T, float>::value>::type> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor* x, const framework::Tensor* y,
                  framework::Tensor* z) {
    const int64_t size = x->numel();
    // A zero-element tensor would otherwise produce a zero-sized grid, which
    // is an invalid launch configuration.
    if (size == 0) return;

    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* z_data = z->data<T>();
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    const bool aligned = IsFloat4Aligned(x_data) && IsFloat4Aligned(y_data) &&
                         IsFloat4Aligned(z_data);
    if (!aligned) {
      AddRangeFunctor<T> functor(x_data, y_data, z_data);
      platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, size);
      for_range(functor);
      return;
    }

    const int64_t vec_size = sizeof(float4) / sizeof(T);
    const int64_t num_vec = (size + vec_size - 1) / vec_size;
    const dim3 grid_size(
        static_cast<unsigned int>((num_vec + PADDLE_CUDA_THREAD_SIZE - 1) /
                                  PADDLE_CUDA_THREAD_SIZE),
        1);
    const dim3 block_size(PADDLE_CUDA_THREAD_SIZE, 1);
    LaunchSameDimsAddCUDAKernel(dev_ctx, grid_size, block_size, x_data, y_data,
                                z_data, size);
  }
};

// Copies dout into both dx and dy. These are the gradients of an elementwise
// add whose operands have identical shapes. Each access moves one float4.
//
// Preconditions (established by the host launcher below):
//   * dout, dx and dy are each aligned to alignof(float4);
//   * sizeof(float4) % sizeof(T) == 0 and vec_size == sizeof(float4) / sizeof(T).
// Launch: 1-D grid of 1-D blocks. Both loops stride by the total number of
// launched threads, so any non-empty grid covers all `size` elements.
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
    const T* __restrict__ dout, int64_t size, int vec_size, T* dx, T* dy) {
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t num_vec = size / vec_size;

  const float4* dout_vec = reinterpret_cast<const float4*>(dout);
  float4* dx_vec = reinterpret_cast<float4*>(dx);
  float4* dy_vec = reinterpret_cast<float4*>(dy);
  for (int64_t i = tid; i < num_vec; i += stride) {
    const float4 value = dout_vec[i];
    dx_vec[i] = value;
    dy_vec[i] = value;
  }

  // Fewer than vec_size elements remain after the last full float4; copy
  // them one element at a time, spread across threads.
  for (int64_t i = num_vec * vec_size + tid; i < size; i += stride) {
    const T value = dout[i];
    dx[i] = value;
    dy[i] = value;
  }
}

// Element-by-element variant of SimpleElemwiseAddGradCUDAKernel. It is used
// when a buffer is not float4-aligned or sizeof(T) does not divide
// sizeof(float4). Launch: 1-D grid of 1-D blocks, strided over `size`.
template <typename T>
static __global__ void SimpleElemwiseAddGradScalarCUDAKernel(
    const T* __restrict__ dout, int64_t size, T* dx, T* dy) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < size; i += stride) {
    const T value = dout[i];
    dx[i] = value;
    dy[i] = value;
  }
}

// Gradient of z = x + y on CUDA for operands of identical shape:
// dx = dout and dy = dout. Allocates dx and dy on ctx.GetPlace() and copies
// dout into both on the context's stream. The element count is taken from
// dout, the tensor that is read. x, y and out are not read; they are kept for
// the shared signature.
// NOTE(review): assumes dx, dy and dout all hold the same number of elements,
// i.e. the caller only dispatches here when dx and dy are both non-null with
// equal dims -- confirm against elementwise_add_op.h.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());

  const int64_t size = dout->numel();
  // Outputs are allocated above. With zero elements, launching would use a
  // zero-sized grid (an invalid configuration), so stop here.
  if (size == 0) return;
  const T* dout_data = dout->data<T>();

  auto& dev_ctx = ctx.template device_context<plat::CUDADeviceContext>();
  const dim3 block_size(PADDLE_CUDA_THREAD_SIZE, 1);

  const bool vectorizable = sizeof(float4) % sizeof(T) == 0 &&
                            IsFloat4Aligned(dout_data) &&
                            IsFloat4Aligned(dx_data) &&
                            IsFloat4Aligned(dy_data);
  if (vectorizable) {
    const int vec_size = static_cast<int>(sizeof(float4) / sizeof(T));
    const int64_t num_vec = (size + vec_size - 1) / vec_size;
    const dim3 grid_size(
        static_cast<unsigned int>((num_vec + PADDLE_CUDA_THREAD_SIZE - 1) /
                                  PADDLE_CUDA_THREAD_SIZE),
        1);
    SimpleElemwiseAddGradCUDAKernel<
        T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
        dout_data, size, vec_size, dx_data, dy_data);
  } else {
    const dim3 grid_size(
        static_cast<unsigned int>((size + PADDLE_CUDA_THREAD_SIZE - 1) /
                                  PADDLE_CUDA_THREAD_SIZE),
        1);
    SimpleElemwiseAddGradScalarCUDAKernel<
        T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(dout_data, size,
                                                            dx_data, dy_data);
  }
}

}  // namespace operators
}  // namespace paddle
// CUDA kernel registrations. Every op below is registered for the same element
// types, in the same order: float, double, int, int64_t, float16, complex64,
// complex128. `grad_add` reuses the forward ElementwiseAddKernel.
using AddCUDAContext = plat::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(
    elementwise_add,
    ops::ElementwiseAddKernel<AddCUDAContext, float>,
    ops::ElementwiseAddKernel<AddCUDAContext, double>,
    ops::ElementwiseAddKernel<AddCUDAContext, int>,
    ops::ElementwiseAddKernel<AddCUDAContext, int64_t>,
    ops::ElementwiseAddKernel<AddCUDAContext, plat::float16>,
    ops::ElementwiseAddKernel<AddCUDAContext, plat::complex64>,
    ops::ElementwiseAddKernel<AddCUDAContext, plat::complex128>);

REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad,
    ops::ElementwiseAddGradKernel<AddCUDAContext, float>,
    ops::ElementwiseAddGradKernel<AddCUDAContext, double>,
    ops::ElementwiseAddGradKernel<AddCUDAContext, int>,
    ops::ElementwiseAddGradKernel<AddCUDAContext, int64_t>,
    ops::ElementwiseAddGradKernel<AddCUDAContext, plat::float16>,
    ops::ElementwiseAddGradKernel<AddCUDAContext, plat::complex64>,
    ops::ElementwiseAddGradKernel<AddCUDAContext, plat::complex128>);

REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad_grad,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, float>,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, double>,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, int>,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, int64_t>,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, plat::float16>,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, plat::complex64>,
    ops::ElementwiseAddDoubleGradKernel<AddCUDAContext, plat::complex128>);

REGISTER_OP_CUDA_KERNEL(
    grad_add,
    ops::ElementwiseAddKernel<AddCUDAContext, float>,
    ops::ElementwiseAddKernel<AddCUDAContext, double>,
    ops::ElementwiseAddKernel<AddCUDAContext, int>,
    ops::ElementwiseAddKernel<AddCUDAContext, int64_t>,
    ops::ElementwiseAddKernel<AddCUDAContext, plat::float16>,
    ops::ElementwiseAddKernel<AddCUDAContext, plat::complex64>,
    ops::ElementwiseAddKernel<AddCUDAContext, plat::complex128>);