/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
W
Wu Yi 已提交
16
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
17
#include "paddle/fluid/platform/complex.h"
18
#include "paddle/fluid/platform/float16.h"
G
gongweibao 已提交
19 20

// Short namespace aliases: `ops` is used by the kernel registrations at the
// bottom of this file, `plat` by the CUDA launch helpers.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Same-shape elementwise subtraction on a CUDA device for a generic element
// type T.
//
// Builds a SubRangeFunctor<T> over the raw buffers of x, y and z and runs it
// through platform::ForRange for x->numel() indices on the
// CUDADeviceContext taken from `ctx`.
template <typename T>
struct SameDimsElemwiseSub<platform::CUDADeviceContext, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor* x, const framework::Tensor* y,
                  framework::Tensor* z) {
    // Raw buffers of the two inputs and of the output.
    auto* lhs = x->data<T>();
    auto* rhs = y->data<T>();
    auto* result = z->data<T>();
    SubRangeFunctor<T> sub_op(lhs, rhs, result);

    // One functor invocation per element of x, on the context's device.
    auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    platform::ForRange<platform::CUDADeviceContext> run_over_range(cuda_ctx,
                                                                   x->numel());
    run_over_range(sub_op);
  }
};

// float16 specialization of same-shape elementwise subtraction on CUDA.
//
// The float16 buffers of x, y and z are reinterpreted as CUDA `half` and
// passed, together with the element count, to SameDimsElemwiseSubCUDAKernel
// on the stream of the context's CUDADeviceContext.
//
// Launch shape: 1-D blocks of PADDLE_CUDA_THREAD_SIZE threads; the grid holds
// enough blocks for ceil(size / 8) threads.
// NOTE(review): the divisor 8 suggests the kernel handles up to 8 elements
// per thread, but the kernel body is not in this file - confirm there.
template <>
struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor* x, const framework::Tensor* y,
                  framework::Tensor* z) {
    auto size = x->numel();
    // With no elements the grid below would have 0 blocks, which the CUDA
    // runtime rejects as an invalid launch configuration. Nothing to compute.
    if (size == 0) {
      return;
    }
    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
                              PADDLE_CUDA_THREAD_SIZE,
                          1);
    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
    const half* x2 =
        reinterpret_cast<const half*>(x->data<platform::float16>());
    const half* y2 =
        reinterpret_cast<const half*>(y->data<platform::float16>());
    half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
    SameDimsElemwiseSubCUDAKernel<<<
        grid_size, block_size, 0,
        ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
        x2, y2, z2, size);
  }
};

// Gradient kernel for out = x - y when dx, dy and dout hold `size` elements
// each: writes dx[i] = dout[i] and dy[i] = -dout[i] for every i in [0, size).
//
// Expects a 1-D launch. Each thread starts at its global index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total number of
// launched threads (blockDim.x * gridDim.x) until it passes `size`, so the
// whole range is covered for any 1-D grid/block size.
template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
                                                       int64_t size, T* dx,
                                                       T* dy) {
  // Index arithmetic is widened to 64 bits before multiplying: `size` is
  // int64_t, and a 32-bit index would wrap (and then index out of bounds)
  // once the global thread index exceeds the 32-bit range.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (; col < size; col += stride) {
    dx[col] = dout[col];
    dy[col] = -dout[col];
  }
}

// CUDA overload of elementwise_sub_grad; participates in overload resolution
// only when DeviceContext is plat::CUDADeviceContext (std::enable_if).
//
// Allocates dx and dy on ctx.GetPlace() and launches
// SimpleElemwiseSubGradCUDAKernel<T> on the context's CUDA stream, which
// writes dx = dout and dy = -dout over x->numel() elements.
//
// `y` and `out` are not read here; `x` is only used for its element count.
// NOTE(review): dx and dy are dereferenced unconditionally - presumably the
// caller guarantees both are non-null; verify against the call site.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  auto size = x->numel();

  // Allocate the outputs first so they exist even when there is no work.
  T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());

  // With no elements the grid below would have 0 blocks, which the CUDA
  // runtime rejects as an invalid launch configuration.
  if (size == 0) {
    return;
  }

  dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
  // One thread per element, rounded up to whole blocks.
  dim3 grid_size =
      dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
  SimpleElemwiseSubGradCUDAKernel<
      T><<<grid_size, block_size, 0,
           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
      dout->data<T>(), size, dx_data, dy_data);
}

}  // namespace operators
}  // namespace paddle
G
gongweibao 已提交
95

Q
QI JUN 已提交
96
// Registers the CUDA forward kernels of the `elementwise_sub` op, one
// ElementwiseSubKernel instantiation per supported element type:
// float, float16, double, int, int64_t, complex<float>, complex<double>.
REGISTER_OP_CUDA_KERNEL(
    elementwise_sub,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::float16>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::complex<float>>,
    ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
                              paddle::platform::complex<double>>);
Q
QI JUN 已提交
108
// Registers the CUDA kernels of the `elementwise_sub_grad` op, one
// ElementwiseSubGradKernel instantiation per supported element type:
// float, float16, double, int, int64_t, complex<float>, complex<double>.
REGISTER_OP_CUDA_KERNEL(
    elementwise_sub_grad,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::float16>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::complex<float>>,
    ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
                                  paddle::platform::complex<double>>);
120 121 122 123 124 125 126 127 128
// Registers the CUDA kernels of the `elementwise_sub_grad_grad` op, one
// ElementwiseSubDoubleGradKernel instantiation per element type:
// float, double, int, int64_t, complex<float>, complex<double>.
// Unlike the two registrations above, no float16 instantiation is listed.
REGISTER_OP_CUDA_KERNEL(
    elementwise_sub_grad_grad,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        double>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<float>>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex<double>>);