/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"

// File-local shorthands: `ops` for the operator kernels registered at the
// bottom of this file, `plat` for platform types (device contexts, float16,
// complex).
namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform;

namespace paddle {
namespace operators {

// Same-shape backward of `out = x - y`: writes dx[i] = dout[i] (when dx is
// non-null) and dy[i] = -dout[i] for every i in [0, size).
//
// Launch: 1-D grid of 1-D blocks. The body is a grid-stride loop, so any grid
// with at least one block covers all `size` elements. Index and stride are
// computed in int64_t so tensors with more than 2^31 - 1 elements do not
// overflow the index (blockIdx/blockDim/gridDim are 32-bit unsigned).
//
// dx may be nullptr; dy must be non-null. Each iteration reads dout[i] before
// writing dx[i] / dy[i] at the same index, so aliasing dx or dy with dout is
// safe.
template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
                                                       int64_t size, T* dx,
                                                       T* dy) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < size; i += stride) {
    const T grad = dout[i];
    if (dx != nullptr) {
      dx[i] = grad;
    }
    dy[i] = -grad;
  }
}

// Launches SimpleElemwiseSubGradCUDAKernel over `size` elements on the CUDA
// stream of `ctx`, using ELEMENTWISE_BLOCK_SIZE threads per block. Returns
// without launching for an empty tensor: a grid with zero blocks is an
// invalid launch configuration.
template <typename T>
static void LaunchSimpleElemwiseSubGradKernel(
    const framework::ExecutionContext& ctx, const T* dout, int64_t size, T* dx,
    T* dy) {
  if (size <= 0) {
    return;
  }
  const dim3 block_size(ELEMENTWISE_BLOCK_SIZE, 1);
  const dim3 grid_size(
      (size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
  SimpleElemwiseSubGradCUDAKernel<T><<<
      grid_size, block_size, 0,
      ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
      dout, size, dx, dy);
}

// General backward of `out = x - y` on CUDA (shapes of x and y may differ).
//
// dx: if dx has dout's dims it is a copy of dout; otherwise dout is summed
//     (AddFunctor + IdentityFunctor) over the dims returned by
//     GetReduceDim(x->dims(), out->dims(), axis).
// dy: if dy has dout's dims it is -dout (element-wise kernel); otherwise dout
//     is summed with InverseFunctor applied, over
//     GetReduceDim(y->dims(), out->dims(), axis).
// Either of dx / dy may be nullptr, in which case it is skipped.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y,
                             const framework::Tensor* out,
                             const framework::Tensor* dout,
                             framework::Tensor* dx, framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");
  auto* dout_data = dout->data<T>();
  // dx
  if (dx != nullptr) {
    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    if (dx->dims() == dout->dims()) {
      // Identity gradient: copying is unnecessary when dx already aliases dout.
      if (dx_data != dout_data) {
        framework::TensorCopy(
            *dout, ctx.GetPlace(),
            ctx.template device_context<platform::DeviceContext>(), dx);
      }
    } else {
      // For inplace strategy, dx will be stored in addr of dout, which makes
      // the result of dy wrong.
      if (dx->IsSharedBufferWith(*dout)) {
        dx->clear();
        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
    }
  }
  // dy
  if (dy != nullptr) {
    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
    if (dy->dims() == dout->dims()) {
      // Unlike the dx copy above, the negation must run even when dy aliases
      // dout: skipping it would leave dy == dout instead of -dout. The kernel
      // is element-wise (reads dout[i] before writing dy[i]), so in-place is
      // safe.
      LaunchSimpleElemwiseSubGradKernel<T>(ctx, dout_data, dy->numel(),
                                           nullptr, dy_data);
    } else {
      // Same inplace hazard as for dx: never reduce into dout's own buffer.
      if (dy->IsSharedBufferWith(*dout)) {
        dy->clear();
        dy->mutable_data<T>(y->dims(), ctx.GetPlace());
      }
      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
      gpuStream_t stream = ctx.cuda_device_context().stream();
      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
          *dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
    }
  }
}

// Fast-path backward of `out = x - y` on CUDA: dx = dout and dy = -dout,
// element-wise, in a single kernel launch sized by x->numel().
// NOTE(review): presumably only called when dout, dx and dy all have x's
// element count (no broadcasting) — confirm against the caller.
// dx may be nullptr (only dy is written then); dy must be non-null.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  T* dx_data = dx != nullptr ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
  LaunchSimpleElemwiseSubGradKernel<T>(ctx, dout->data<T>(), x->numel(),
                                       dx_data, dy_data);
}

}  // namespace operators
}  // namespace paddle
// CUDA kernels for the forward op, listed per element type.
REGISTER_OP_CUDA_KERNEL(
    elementwise_sub,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseSubKernel<plat::CUDADeviceContext, plat::complex<double>>);

// CUDA kernels for the first-order gradient op (same element types as the
// forward op).
REGISTER_OP_CUDA_KERNEL(
    elementwise_sub_grad,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseSubGradKernel<plat::CUDADeviceContext,
                                  plat::complex<double>>);

// CUDA kernels for the second-order gradient op.
// NOTE(review): unlike the two registrations above, no float16 kernel is
// registered here — confirm whether that omission is intentional.
REGISTER_OP_CUDA_KERNEL(
    elementwise_sub_grad_grad,
    ops::ElementwiseSubDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseSubDoubleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseSubDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseSubDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseSubDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseSubDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);