/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
16
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
17
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
18
#include "paddle/fluid/platform/complex.h"
W
Wu Yi 已提交
19
#include "paddle/fluid/platform/float16.h"
20 21

namespace ops = paddle::operators;
W
Wu Yi 已提交
22
namespace plat = paddle::platform;
23

24 25 26
namespace paddle {
namespace operators {

// CUDA forward kernel for elementwise_mul: Out = X * Y, computed with
// MulFunctor<T> through the generic binary elementwise launcher.
template <typename T>
class ElementwiseMulKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto& dev_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    // Scratch tensor handed to PackTensorsIntoVector; judging by its use it
    // presumably holds a dense copy of X when X is a SelectedRows — verify
    // against PackTensorsIntoVector. Kept alive until after the launch.
    framework::Tensor selected_rows_x;
    std::vector<const framework::Tensor*> inputs;
    std::vector<framework::Tensor*> outputs;

    // The returned axis is forwarded unchanged to the launcher.
    const int bcast_axis =
        PackTensorsIntoVector<T>(ctx, &inputs, &outputs, &selected_rows_x);

    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        dev_ctx, inputs, &outputs, bcast_axis, MulFunctor<T>());
  }
};

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
template <typename T>
static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
                                                       const T* out,
                                                       const T* dout,
                                                       int64_t size, T* dx,
                                                       T* dy) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  while (col < size) {
    T o = dout[col];
    dx[col] = y[col] * o;
    dy[col] = x[col] * o;
    col += blockDim.x * gridDim.x;
  }
}

// complex<float> specialization of the same-shape multiply gradient. The
// gradient of each operand uses the complex conjugate of the other operand:
//   dx[i] = conj(y[i]) * dout[i],  dy[i] = conj(x[i]) * dout[i].
// `out` is unused; it is kept to match the primary template's signature.
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<float>>(
    const plat::complex<float>* x, const plat::complex<float>* y,
    const plat::complex<float>* out, const plat::complex<float>* dout,
    int64_t size, plat::complex<float>* dx, plat::complex<float>* dy) {
  // 64-bit index/stride: `size` is int64_t, and a 32-bit int would overflow
  // for tensors with more than INT_MAX elements.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < size; i += stride) {
    plat::complex<float> o = dout[i];
    dx[i] = plat::complex<float>(y[i].real, -y[i].imag) * o;
    dy[i] = plat::complex<float>(x[i].real, -x[i].imag) * o;
  }
}

// complex<double> specialization of the same-shape multiply gradient. The
// gradient of each operand uses the complex conjugate of the other operand:
//   dx[i] = conj(y[i]) * dout[i],  dy[i] = conj(x[i]) * dout[i].
// `out` is unused; it is kept to match the primary template's signature.
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
    const plat::complex<double>* x, const plat::complex<double>* y,
    const plat::complex<double>* out, const plat::complex<double>* dout,
    int64_t size, plat::complex<double>* dx, plat::complex<double>* dy) {
  // 64-bit index/stride: `size` is int64_t, and a 32-bit int would overflow
  // for tensors with more than INT_MAX elements.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < size; i += stride) {
    plat::complex<double> o = dout[i];
    dx[i] = plat::complex<double>(y[i].real, -y[i].imag) * o;
    dy[i] = plat::complex<double>(x[i].real, -x[i].imag) * o;
  }
}

// CUDA overload (selected via enable_if on DeviceContext) of the
// elementwise_mul gradient. It computes dX and dY in one kernel over
// x->numel() elements on the device context's stream.
//
// Assumes x, y, dout, dx and dy all have x->numel() elements with identical
// layout (no broadcasting), and that dx and dy are non-null. Both are
// dereferenced unconditionally. NOTE(review): presumably the caller only
// takes this path when those hold — verify.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_mul_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  // Allocate the outputs first so they are initialized even when the launch
  // below is skipped for an empty input.
  T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());

  const auto size = x->numel();
  // A zero-block grid is an invalid launch configuration (it would leave a
  // pending CUDA error), and there is no work to do anyway.
  if (size <= 0) return;

  dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
  dim3 grid_size =
      dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
  auto stream =
      ctx.template device_context<plat::CUDADeviceContext>().stream();
  SimpleElemwiseMulGradCUDAKernel<T><<<grid_size, block_size, 0, stream>>>(
      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
      dx_data, dy_data);
}

}  // namespace operators
}  // namespace paddle

// Forward CUDA kernels for elementwise_mul, one per supported element type.
REGISTER_OP_CUDA_KERNEL(
    elementwise_mul,
    // floating point
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
    // integral and boolean
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, bool>,
    // half precision and complex
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<double>>);
// First-order gradient CUDA kernels for elementwise_mul, one per element type.
REGISTER_OP_CUDA_KERNEL(
    elementwise_mul_grad,
    // floating point
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
    // integral and boolean
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, bool>,
    // half precision and complex
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
                                  plat::complex<double>>);
// Second-order (double) gradient CUDA kernels for elementwise_mul, one per
// element type.
REGISTER_OP_CUDA_KERNEL(
    elementwise_mul_grad_grad,
    // floating point
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
    // integral and boolean
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, bool>,
    // half precision and complex
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);