// custom_relu_op.cu
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

// Fails the op through PD_CHECK unless tensor argument `x` is placed on the
// GPU. `#x` stringifies the argument so the message names the bad tensor.
#define CHECK_GPU_INPUT(x)                         \
  PD_CHECK((x).place() == paddle::PlaceType::kGPU, \
           #x " must be a GPU Tensor.")

// Element-wise ReLU forward: y[i] = (x[i] > 0) ? x[i] : 0 for i in [0, num).
//
// Launch: any 1-D grid/block shape. The grid-stride loop visits every index
// in [0, num) no matter how many threads are launched. No shared memory.
// Indices and the stride are 64-bit. With 32-bit `int` arithmetic, a tensor
// of roughly INT_MAX elements could wrap the index negative, pass `i < num`,
// and write out of bounds.
template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
                                         data_t* y,
                                         const int64_t num) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num;
       i += stride) {
    y[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
  }
}

// Element-wise ReLU backward: dx[i] = dy[i] * (y[i] > 0 ? 1 : 0), where `y` is
// the forward output and `dy` the incoming gradient.
//
// Launch: any 1-D grid/block shape, because a grid-stride loop covers all
// `num` elements. No shared memory. Uses 64-bit indexing so very large
// tensors cannot wrap the loop index negative and write out of bounds.
template <typename data_t>
__global__ void relu_cuda_backward_kernel(const data_t* dy,
                                          const data_t* y,
                                          data_t* dx,
                                          const int64_t num) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num;
       i += stride) {
    // Multiply rather than select so a non-finite dy still propagates.
    dx[i] = dy[i] * (y[i] > static_cast<data_t>(0.) ? static_cast<data_t>(1.)
                                                    : static_cast<data_t>(0.));
  }
}

// Element-wise ReLU double backward:
//   ddout[i] = ddx[i] * (out[i] > 0 ? 1 : 0) for i in [0, num),
// where `out` is the forward output.
//
// Launch: any 1-D grid/block shape, because a grid-stride loop covers all
// `num` elements. No shared memory.
// The loop used to start at `i = num`, so it never ran and `ddout` was left
// unwritten. It now starts at the thread's global index, computed in 64-bit
// so the block offset cannot overflow 32-bit unsigned arithmetic.
template <typename data_t>
__global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
                                                 const data_t* ddx_data,
                                                 data_t* ddout_data,
                                                 int64_t num) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num;
       i += stride) {
    ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
                                       ? static_cast<data_t>(1.)
                                       : static_cast<data_t>(0.));
  }
}

// GPU ReLU forward.
//
// Params:
//   x: input tensor; must be on the GPU (enforced by CHECK_GPU_INPUT).
// Returns:
//   {out}: a new GPU tensor with x's shape. Its element type is the one
//   PD_DISPATCH_FLOATING_AND_HALF_TYPES selects from x.type().
// The kernel is enqueued asynchronously on x.stream(); nothing here waits for
// it or checks the launch status.
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
  CHECK_GPU_INPUT(x);
  auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());

  // Keep the element count in 64 bits instead of narrowing it to `int`.
  const int64_t numel = x.size();
  const int block = 512;
  const int64_t grid = (numel + block - 1) / block;

  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      x.type(), "relu_cuda_forward_kernel", ([&] {
        auto* out_data = out.mutable_data<data_t>(x.place());
        // An empty tensor would give a 0-block grid, which is an invalid
        // launch configuration. The output is still allocated above.
        if (numel > 0) {
          relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
              x.data<data_t>(), out_data, numel);
        }
      }));

  return {out};
}

// GPU ReLU backward.
//
// Params:
//   x:        forward input. Only its shape, place and stream are used here.
//   out:      forward output. It drives the mask and the element type.
//   grad_out: gradient w.r.t. `out`.
//   All three must be on the GPU (enforced by CHECK_GPU_INPUT).
// Returns:
//   {grad_x}: a new GPU tensor with x's shape.
// NOTE(review): the element count comes from `out`, but grad_x takes x's
// shape. This assumes x, out and grad_out have equal sizes; confirm at caller.
// The kernel is enqueued asynchronously on x.stream().
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out) {
  CHECK_GPU_INPUT(x);
  CHECK_GPU_INPUT(out);
  CHECK_GPU_INPUT(grad_out);
  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());

  // Keep the element count in 64 bits instead of narrowing it to `int`.
  const int64_t numel = out.size();
  const int block = 512;
  const int64_t grid = (numel + block - 1) / block;

  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        auto* grad_x_data = grad_x.mutable_data<data_t>(x.place());
        // Skip the launch for empty tensors: a 0-block grid is invalid.
        if (numel > 0) {
          relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
              grad_out.data<data_t>(), out.data<data_t>(), grad_x_data, numel);
        }
      }));

  return {grad_x};
}

// GPU ReLU double backward: ddout = ddx * (out > 0 ? 1 : 0), element-wise.
//
// Params:
//   out: forward output. It drives the mask and the element type.
//   ddx: second-order input gradient.
//   Both must be on the GPU (enforced by CHECK_GPU_INPUT).
// Returns:
//   {ddout}: a new GPU tensor with out's shape.
// The kernel is enqueued asynchronously on out.stream().
std::vector<paddle::Tensor> relu_cuda_double_backward(
    const paddle::Tensor& out, const paddle::Tensor& ddx) {
  CHECK_GPU_INPUT(out);
  CHECK_GPU_INPUT(ddx);
  auto ddout = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());

  const int64_t numel = out.size();
  const int64_t block = 512;
  const int64_t grid = (numel + block - 1) / block;
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_double_backward_kernel", ([&] {
        auto* ddout_data = ddout.mutable_data<data_t>(out.place());
        // Skip the launch for empty tensors: a 0-block grid is invalid.
        if (numel > 0) {
          relu_cuda_double_backward_kernel<
              data_t><<<grid, block, 0, out.stream()>>>(
              out.data<data_t>(), ddx.data<data_t>(), ddout_data, numel);
        }
      }));

  // NOTE(review): unconditional debug output, kept in case callers/tests look
  // for it. No synchronization or error check precedes it, so it only means
  // the kernel was enqueued, not that it completed.
  std::cout << "Debug info: run relu gpu double backward success." << std::endl;

  return {ddout};
}

// GPU ReLU backward that needs only the forward output (not the input x).
//
// Params:
//   out:      forward output. It drives the mask, the shape and the element
//             type.
//   grad_out: gradient w.r.t. `out`.
//   Both must be on the GPU. The check was missing here, unlike the other
//   GPU entry points.
// Returns:
//   {grad_x}: a new GPU tensor with out's shape.
// The kernel is enqueued asynchronously on out.stream().
std::vector<paddle::Tensor> relu_cuda_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  CHECK_GPU_INPUT(out);
  CHECK_GPU_INPUT(grad_out);
  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());

  // Keep the element count in 64 bits instead of narrowing it to `int`.
  const int64_t numel = out.size();
  const int block = 512;
  const int64_t grid = (numel + block - 1) / block;
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        auto* grad_x_data = grad_x.mutable_data<data_t>(out.place());
        // Skip the launch for empty tensors: a 0-block grid is invalid.
        if (numel > 0) {
          relu_cuda_backward_kernel<data_t><<<grid, block, 0, out.stream()>>>(
              grad_out.data<data_t>(), out.data<data_t>(), grad_x_data, numel);
        }
      }));

  return {grad_x};
}
// GPU ReLU forward that writes into a caller-provided tensor.
//
// Params:
//   x:   input tensor; must be on the GPU (now enforced by CHECK_GPU_INPUT).
//   out: destination. It is reshaped to x's shape and its storage is obtained
//        via mutable_data on x's place. Assumed non-null (not checked here).
// The kernel is enqueued asynchronously on x.stream().
void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
  CHECK_GPU_INPUT(x);

  // Keep the element count in 64 bits instead of narrowing it to `int`.
  const int64_t numel = x.size();
  const int block = 512;
  const int64_t grid = (numel + block - 1) / block;

  out->reshape(x.shape());
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      x.type(), "relu_cuda_forward_kernel", ([&] {
        auto* out_data = out->mutable_data<data_t>(x.place());
        // Skip the launch for empty tensors: a 0-block grid is invalid.
        if (numel > 0) {
          relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
              x.data<data_t>(), out_data, numel);
        }
      }));
}

// GPU ReLU backward that writes into a caller-provided gradient tensor.
//
// Params:
//   x:        forward input. Only its shape, place and stream are used.
//   out:      forward output. It drives the mask and the element type.
//   grad_out: gradient w.r.t. `out`.
//   grad_x:   destination. It is reshaped to x's shape. Assumed non-null
//             (not checked here).
//   The three inputs must be on the GPU (now enforced by CHECK_GPU_INPUT).
// NOTE(review): the element count comes from `out`, but grad_x takes x's
// shape. This assumes the sizes match; confirm at caller.
// The kernel is enqueued asynchronously on x.stream().
void relu_cuda_backward_out(const paddle::Tensor& x,
                            const paddle::Tensor& out,
                            const paddle::Tensor& grad_out,
                            paddle::Tensor* grad_x) {
  CHECK_GPU_INPUT(x);
  CHECK_GPU_INPUT(out);
  CHECK_GPU_INPUT(grad_out);

  // Keep the element count in 64 bits instead of narrowing it to `int`.
  const int64_t numel = out.size();
  const int block = 512;
  const int64_t grid = (numel + block - 1) / block;

  grad_x->reshape(x.shape());
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        auto* grad_x_data = grad_x->mutable_data<data_t>(x.place());
        // Skip the launch for empty tensors: a 0-block grid is invalid.
        if (numel > 0) {
          relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
              grad_out.data<data_t>(), out.data<data_t>(), grad_x_data, numel);
        }
      }));
}