/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
16
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
17
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
18
#include "paddle/fluid/platform/complex.h"
W
Wu Yi 已提交
19
#include "paddle/fluid/platform/float16.h"
20 21

// File-local shorthands: `plat` for paddle::platform (device contexts,
// float16, complex) and `ops` for paddle::operators (kernel classes named in
// the REGISTER_OP_CUDA_KERNEL calls).
namespace plat = paddle::platform;
namespace ops = paddle::operators;

namespace paddle {
namespace operators {

27
// Binary functor handed to LaunchElementwiseCudaKernel: `args` points at the
// two operands of one output element, and the result is their product.
template <typename T>
struct CudaMulFunctor {
  inline HOSTDEVICE T operator()(const T* args) const {
    const T& lhs = args[0];
    const T& rhs = args[1];
    return lhs * rhs;
  }
};

template <typename T>
class ElementwiseMulKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Computes Out = X * Y on the GPU.
  //
  // X may be a LoDTensor, in which case inputs/outputs and the axis come from
  // PackTensorsIntoVector, or a SelectedRows, in which case Y must hold a
  // single element and only X's value tensor is multiplied. Out then keeps
  // X's rows and height.
  // Any other variable type for X raises InvalidArgument.
  void Compute(const framework::ExecutionContext& ctx) const override {
    int axis = -1;
    auto x_var = ctx.InputVar("X");
    PADDLE_ENFORCE_NOT_NULL(
        x_var, platform::errors::InvalidArgument(
                   "Cannot get input Variable X, Variable name = %s.",
                   ctx.InputName("X")));
    auto* y = ctx.Input<framework::LoDTensor>("Y");

    // `x` is declared at function scope because, in the SelectedRows branch,
    // its address is stored in `ins` and must remain valid until the launch.
    framework::Tensor x;
    std::vector<const framework::Tensor*> ins;
    std::vector<framework::Tensor*> outs;
    const auto& cuda_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();

    if (x_var->IsType<framework::LoDTensor>()) {
      // Fills `ins`/`outs` from ctx and returns the axis to broadcast along.
      axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
    } else if (x_var->IsType<framework::SelectedRows>()) {
      PADDLE_ENFORCE_EQ(y->dims().size() == 1 && y->dims()[0] == 1, true,
                        platform::errors::InvalidArgument(
                            "For elementwise_op, if X is Sparse, Y must be "
                            "scalar. But received the dims of Y = %s.",
                            y->dims()));
      auto& x_sele = x_var->Get<framework::SelectedRows>();
      auto* out_sele = ctx.Output<framework::SelectedRows>("Out");
      x = x_sele.value();

      // Out mirrors X's sparse layout: same rows, same height, and a value
      // tensor of the same shape, allocated once with element type T.
      out_sele->set_rows(x_sele.rows());
      out_sele->set_height(x_sele.height());
      framework::Tensor* z = out_sele->mutable_value();
      z->Resize(x_sele.value().dims());
      z->mutable_data<T>(ctx.GetPlace());

      outs.emplace_back(z);
      ins.emplace_back(&x);
      ins.emplace_back(y);

      axis = ctx.HasAttr("axis") ? ctx.Attr<int>("axis") : -1;
      // -1 means "use the rank difference between X and Y".
      axis = axis == -1 ? std::abs(y->dims().size() - x.dims().size()) : axis;
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "X's type[%s] is not supported by elementwise_op. X's type should be "
          "LoDTensor or SelectedRows.",
          framework::ToTypeName(x_var->Type())));
    }

    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
        cuda_ctx, ins, &outs, axis, CudaMulFunctor<T>());
  }
};

// Gradient of Out = X * Y for same-shaped, contiguous inputs of `size`
// elements:
//   dx[i] = y[i] * dout[i]
//   dy[i] = x[i] * dout[i]
// 1-D launch. Each thread starts at its global index and advances by the
// total number of launched threads (grid-stride loop), so any grid size
// covers all elements. `out` is not read.
template <typename T>
static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
                                                       const T* out,
                                                       const T* dout,
                                                       int64_t size, T* dx,
                                                       T* dy) {
  // 64-bit index and stride: `size` is int64_t, and an `int` index would
  // overflow (undefined behavior) once size exceeds INT_MAX.
  int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  for (; col < size; col += stride) {
    T o = dout[col];
    dx[col] = y[col] * o;
    dy[col] = x[col] * o;
  }
}

// complex<float> specialization. Each gradient multiplies dout by the
// conjugate (real, -imag) of the other operand:
//   dx[i] = conj(y[i]) * dout[i]
//   dy[i] = conj(x[i]) * dout[i]
// Same 1-D grid-stride iteration as the primary template. `out` is not read.
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<float>>(
    const plat::complex<float>* x, const plat::complex<float>* y,
    const plat::complex<float>* out, const plat::complex<float>* dout,
    int64_t size, plat::complex<float>* dx, plat::complex<float>* dy) {
  // 64-bit index and stride so that size > INT_MAX cannot overflow.
  int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  for (; col < size; col += stride) {
    plat::complex<float> o = dout[col];
    dx[col] = plat::complex<float>(y[col].real, -y[col].imag) * o;
    dy[col] = plat::complex<float>(x[col].real, -x[col].imag) * o;
  }
}

// complex<double> specialization. Each gradient multiplies dout by the
// conjugate (real, -imag) of the other operand:
//   dx[i] = conj(y[i]) * dout[i]
//   dy[i] = conj(x[i]) * dout[i]
// Same 1-D grid-stride iteration as the primary template. `out` is not read.
template <>
__global__ void SimpleElemwiseMulGradCUDAKernel<plat::complex<double>>(
    const plat::complex<double>* x, const plat::complex<double>* y,
    const plat::complex<double>* out, const plat::complex<double>* dout,
    int64_t size, plat::complex<double>* dx, plat::complex<double>* dy) {
  // 64-bit index and stride so that size > INT_MAX cannot overflow.
  int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  for (; col < size; col += stride) {
    plat::complex<double> o = dout[col];
    dx[col] = plat::complex<double>(y[col].real, -y[col].imag) * o;
    dy[col] = plat::complex<double>(x[col].real, -x[col].imag) * o;
  }
}

// CUDA overload (enabled only when DeviceContext is CUDADeviceContext) of the
// same-shape elementwise_mul gradient. It allocates dx/dy on ctx's place, then
// launches SimpleElemwiseMulGradCUDAKernel on the context's stream to compute
// dx = y * dout and dy = x * dout.
// NOTE(review): assumes x, y, dout, dx and dy all have x->numel() elements;
// the kernel indexes every buffer with the same flat index.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_mul_grad(const framework::ExecutionContext& ctx,
                     const framework::Tensor* x, const framework::Tensor* y,
                     const framework::Tensor* out,
                     const framework::Tensor* dout, framework::Tensor* dx,
                     framework::Tensor* dy) {
  // Allocate the gradients first so the outputs are set up even when there is
  // nothing to compute.
  T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
  T* dy_data = dy->mutable_data<T>(ctx.GetPlace());

  const int64_t size = x->numel();
  // A launch with zero blocks is an invalid configuration; skip empty input.
  if (size <= 0) return;

  // dim3 holds unsigned ints, and gridDim.x is limited to 2^31 - 1. Clamp the
  // block count; the kernel's stride loop covers any remaining elements.
  constexpr int64_t kMaxGridDimX = 2147483647;
  int64_t blocks =
      (size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE;
  if (blocks > kMaxGridDimX) blocks = kMaxGridDimX;

  dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
  dim3 grid_size = dim3(static_cast<unsigned int>(blocks), 1);
  auto stream =
      ctx.template device_context<plat::CUDADeviceContext>().stream();

  SimpleElemwiseMulGradCUDAKernel<T><<<grid_size, block_size, 0, stream>>>(
      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
      dx_data, dy_data);
}

}  // namespace operators
}  // namespace paddle

Q
QI JUN 已提交
158
// Forward elementwise_mul: one CUDA ElementwiseMulKernel instantiation for
// each supported element type (floating point, integer, float16, complex).
REGISTER_OP_CUDA_KERNEL(
    elementwise_mul,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex<float>>,
    ops::ElementwiseMulKernel<plat::CUDADeviceContext,
                              plat::complex<double>>);

// First-order gradient of elementwise_mul, registered for the same element
// types as the forward kernel.
REGISTER_OP_CUDA_KERNEL(
    elementwise_mul_grad,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
                                  plat::complex<float>>,
    ops::ElementwiseMulGradKernel<plat::CUDADeviceContext,
                                  plat::complex<double>>);
// Second-order (double) gradient of elementwise_mul, registered for the same
// element types as the forward and first-order gradient kernels.
REGISTER_OP_CUDA_KERNEL(
    elementwise_mul_grad_grad,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::float16>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<float>>,
    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                        plat::complex<double>>);