elementwise_div_op.cu 6.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
16
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
17
#include "paddle/fluid/platform/complex.h"
W
Wu Yi 已提交
18
#include "paddle/fluid/platform/float16.h"
G
gongweibao 已提交
19 20

namespace ops = paddle::operators;
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

// Element-wise backward kernel for out = x / y, where x, y, out, dout, dx and
// dy each hold `size` elements (no broadcasting is done here):
//   dx = dout / y
//   dy = -dout * out / y
// `x` is not read; it is kept so every specialization shares one signature.
//
// Launch: 1-D grid of 1-D blocks. Each thread starts at its global index and
// advances by the total number of launched threads (grid-stride loop), so any
// non-zero grid size covers all `size` elements.
// The index and stride are 64-bit so tensors with more than 2^31 - 1 elements
// do not overflow the index (the previous `int` index did).
template <typename T>
static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
                                                       const T* out,
                                                       const T* dout,
                                                       int64_t size, T* dx,
                                                       T* dy) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t col =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       col < size; col += stride) {
    const T o = dout[col];
    dx[col] = o / y[col];
    dy[col] = -o * out[col] / y[col];
  }
}

42
// complex<float> specialization of the element-wise division backward kernel.
// Gradients are taken with respect to the conjugates:
//   dx = dout / conj(y)
//   dy = -dout * conj(out / y)
// `x` is not read; it is kept to match the primary template's signature.
//
// Launch: 1-D grid of 1-D blocks; grid-stride loop with a 64-bit index so
// tensors with more than 2^31 - 1 elements do not overflow the index.
// out / y is now computed once per element (it was computed twice before).
template <>
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
    const paddle::platform::complex<float>* x,
    const paddle::platform::complex<float>* y,
    const paddle::platform::complex<float>* out,
    const paddle::platform::complex<float>* dout, int64_t size,
    paddle::platform::complex<float>* dx,
    paddle::platform::complex<float>* dy) {
  using Complex = paddle::platform::complex<float>;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t col =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       col < size; col += stride) {
    const Complex o = dout[col];
    const Complex y_val = y[col];
    const Complex y_conj(y_val.real, -y_val.imag);
    const Complex out_div_y = out[col] / y_val;
    const Complex out_div_y_conj(out_div_y.real, -out_div_y.imag);
    dx[col] = o / y_conj;
    dy[col] = -o * out_div_y_conj;
  }
}

// complex<double> specialization of the element-wise division backward kernel.
// Gradients are taken with respect to the conjugates:
//   dx = dout / conj(y)
//   dy = -dout * conj(out / y)
// `x` is not read; it is kept to match the primary template's signature.
//
// Launch: 1-D grid of 1-D blocks; grid-stride loop with a 64-bit index so
// tensors with more than 2^31 - 1 elements do not overflow the index.
// out / y is now computed once per element (it was computed twice before).
template <>
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
    const paddle::platform::complex<double>* x,
    const paddle::platform::complex<double>* y,
    const paddle::platform::complex<double>* out,
    const paddle::platform::complex<double>* dout, int64_t size,
    paddle::platform::complex<double>* dx,
    paddle::platform::complex<double>* dy) {
  using Complex = paddle::platform::complex<double>;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t col =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       col < size; col += stride) {
    const Complex o = dout[col];
    const Complex y_val = y[col];
    const Complex y_conj(y_val.real, -y_val.imag);
    const Complex out_div_y = out[col] / y_val;
    const Complex out_div_y_conj(out_div_y.real, -out_div_y.imag);
    dx[col] = o / y_conj;
    dy[col] = -o * out_div_y_conj;
  }
}

86 87 88 89 90 91 92 93
// CUDA path of the elementwise_div backward pass (selected via enable_if only
// when DeviceContext is plat::CUDADeviceContext). Allocates dx and dy on
// ctx.GetPlace() and launches SimpleElemwiseDivGradCUDAKernel on the stream of
// ctx's CUDA device context, using x->numel() as the element count for every
// tensor.
// NOTE(review): assumes x, y, out, dout, dx and dy all hold the same number of
// elements and that dx and dy are non-null (both are dereferenced) -- callers
// are expected to guarantee this before taking this path; confirm.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  // Allocate the outputs up front so they exist on the target place even when
  // there is nothing to compute.
  T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());

  auto size = x->numel();
  // A grid of 0 blocks is an invalid launch configuration, so empty tensors
  // must not reach the kernel launch.
  if (size == 0) {
    return;
  }

  dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
  dim3 grid_size =
      dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
  auto stream =
      ctx.template device_context<plat::CUDADeviceContext>().stream();
  SimpleElemwiseDivGradCUDAKernel<T><<<grid_size, block_size, 0, stream>>>(
      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
      dx_data, dy_data);
}

}  // namespace operators
}  // namespace paddle
G
gongweibao 已提交
107

Q
QI JUN 已提交
108
// Registers the CUDA kernels of the elementwise_div operator for float,
// float16, double, int, int64_t, complex<float> and complex<double>.
REGISTER_OP_CUDA_KERNEL(
    elementwise_div,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseDivKernel<plat::CUDADeviceContext, plat::complex<double>>);
Q
QI JUN 已提交
120
// Registers the CUDA kernels of elementwise_div_grad for float, float16,
// double, int, int64_t, complex<float> and complex<double>.
REGISTER_OP_CUDA_KERNEL(
    elementwise_div_grad,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseDivGradKernel<plat::CUDADeviceContext,
                                  plat::complex<double>>);
132 133 134 135
// Registers the CUDA kernels of elementwise_div_grad_grad for float, float16,
// double, int, int64_t, complex<float> and complex<double>.
REGISTER_OP_CUDA_KERNEL(
    elementwise_div_grad_grad,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::float16>,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseDivDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);