// custom_relu_op.cu
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

// Element-wise ReLU forward pass: y[i] = x[i] if x[i] > 0, otherwise 0,
// for every i in [0, num).
//
// Uses a grid-stride loop: each thread starts at its flat global index and
// advances by the total number of launched threads (blockDim.x * gridDim.x),
// so any 1-D launch configuration covers all `num` elements.
//
// x   - input values (read-only); indexed over [0, num).
// y   - output values; indexed over [0, num).
// num - number of elements to process.
template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
                                         data_t* y,
                                         const int num) {
  // Index math is done in 64 bits. With an int index, `i += stride` wraps to a
  // negative value once i + stride exceeds INT_MAX (possible for large `num`),
  // and the wrapped index would still satisfy `i < num`, causing an
  // out-of-bounds access.
  const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
  const long long first =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (long long i = first; i < num; i += stride) {
    y[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
  }
}

// Element-wise ReLU backward pass: dx[i] = dy[i] where y[i] > 0, and
// dy[i] * 0 elsewhere, for every i in [0, num).
//
// Uses a grid-stride loop: each thread starts at its flat global index and
// advances by the total number of launched threads (blockDim.x * gridDim.x),
// so any 1-D launch configuration covers all `num` elements.
//
// dy  - upstream gradient (read-only); indexed over [0, num).
// y   - forward output of ReLU (read-only); its sign selects the mask.
// dx  - gradient written by this kernel; indexed over [0, num).
// num - number of elements to process.
template <typename data_t>
__global__ void relu_cuda_backward_kernel(const data_t* dy,
                                          const data_t* y,
                                          data_t* dx,
                                          const int num) {
  // Index math is done in 64 bits. With an int index, `i += stride` wraps to a
  // negative value once i + stride exceeds INT_MAX (possible for large `num`),
  // and the wrapped index would still satisfy `i < num`, causing an
  // out-of-bounds access.
  const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
  const long long first =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (long long i = first; i < num; i += stride) {
    // The mask is applied by multiplication (not selection) to keep the
    // original arithmetic unchanged.
    dx[i] = dy[i] * (y[i] > static_cast<data_t>(0.) ? static_cast<data_t>(1.)
                                                    : static_cast<data_t>(0.));
  }
}

// Applies ReLU to `x` on the GPU and returns the result as a new tensor.
//
// x - input tensor. The kernel is enqueued on x.stream() and the output
//     buffer is allocated on x.place().
//
// Returns a one-element vector holding the output tensor, which is created
// with the same shape as `x`.
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());

  // NOTE(review): the element count is kept as int because the kernel takes an
  // int count; if x.size() is wider than int, counts above INT_MAX would be
  // truncated here -- confirm against the Tensor API.
  int numel = x.size();
  int block = 512;
  // Ceil-divide in 64 bits so `numel + block - 1` cannot overflow int, and
  // always launch at least one block: a grid dimension of 0 is an invalid
  // launch configuration, so an empty tensor would otherwise produce a CUDA
  // launch error instead of a no-op.
  long long blocks_needed =
      (static_cast<long long>(numel) + block - 1) / block;
  int grid = blocks_needed > 0 ? static_cast<int>(blocks_needed) : 1;
  // data_t is supplied by the dispatch macro according to x.type().
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      x.type(), "relu_cuda_forward_kernel", ([&] {
        relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
      }));

  return {out};
}

// Computes the ReLU input gradient on the GPU and returns it as a new tensor.
//
// x        - forward input; only its shape, stream and place are used here.
// out      - forward output of ReLU; supplies the element count, the dispatch
//            type, and the sign mask read by the kernel.
// grad_out - gradient with respect to `out`.
//
// Returns a one-element vector holding grad_x, created with the shape of `x`.
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out) {
  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape());

  // NOTE(review): the element count is kept as int because the kernel takes an
  // int count; if out.size() is wider than int, counts above INT_MAX would be
  // truncated here -- confirm against the Tensor API.
  int numel = out.size();
  int block = 512;
  // Ceil-divide in 64 bits so `numel + block - 1` cannot overflow int, and
  // always launch at least one block: a grid dimension of 0 is an invalid
  // launch configuration, so an empty tensor would otherwise produce a CUDA
  // launch error instead of a no-op.
  long long blocks_needed =
      (static_cast<long long>(numel) + block - 1) / block;
  int grid = blocks_needed > 0 ? static_cast<int>(blocks_needed) : 1;
  // data_t is supplied by the dispatch macro according to out.type().
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
            grad_out.data<data_t>(),
            out.data<data_t>(),
            grad_x.mutable_data<data_t>(x.place()),
            numel);
      }));

  return {grad_x};
}

// Computes the ReLU input gradient on the GPU using only the forward output,
// and returns it as a new tensor.
//
// out      - forward output of ReLU; supplies the shape, element count,
//            dispatch type, stream and place, and the sign mask read by the
//            kernel.
// grad_out - gradient with respect to `out`.
//
// Returns a one-element vector holding grad_x, created with the shape of
// `out`.
std::vector<paddle::Tensor> relu_cuda_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());

  // NOTE(review): the element count is kept as int because the kernel takes an
  // int count; if out.size() is wider than int, counts above INT_MAX would be
  // truncated here -- confirm against the Tensor API.
  int numel = out.size();
  int block = 512;
  // Ceil-divide in 64 bits so `numel + block - 1` cannot overflow int, and
  // always launch at least one block: a grid dimension of 0 is an invalid
  // launch configuration, so an empty tensor would otherwise produce a CUDA
  // launch error instead of a no-op.
  long long blocks_needed =
      (static_cast<long long>(numel) + block - 1) / block;
  int grid = blocks_needed > 0 ? static_cast<int>(blocks_needed) : 1;
  // data_t is supplied by the dispatch macro according to out.type().
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        relu_cuda_backward_kernel<data_t><<<grid, block, 0, out.stream()>>>(
            grad_out.data<data_t>(),
            out.data<data_t>(),
            grad_x.mutable_data<data_t>(out.place()),
            numel);
      }));

  return {grad_x};
}

// Applies ReLU to `x` on the GPU, writing the result into a caller-provided
// tensor.
//
// x   - input tensor. The kernel is enqueued on x.stream().
// out - destination tensor; it is reshaped to x.shape() and its buffer is
//       obtained on x.place() before the kernel is launched.
void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
  // NOTE(review): the element count is kept as int because the kernel takes an
  // int count; if x.size() is wider than int, counts above INT_MAX would be
  // truncated here -- confirm against the Tensor API.
  int numel = x.size();
  int block = 512;
  // Ceil-divide in 64 bits so `numel + block - 1` cannot overflow int, and
  // always launch at least one block: a grid dimension of 0 is an invalid
  // launch configuration, so an empty tensor would otherwise produce a CUDA
  // launch error instead of a no-op.
  long long blocks_needed =
      (static_cast<long long>(numel) + block - 1) / block;
  int grid = blocks_needed > 0 ? static_cast<int>(blocks_needed) : 1;
  out->reshape(x.shape());
  // data_t is supplied by the dispatch macro according to x.type().
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      x.type(), "relu_cuda_forward_kernel", ([&] {
        relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
            x.data<data_t>(), out->mutable_data<data_t>(x.place()), numel);
      }));
}

// Computes the ReLU input gradient on the GPU, writing the result into a
// caller-provided tensor.
//
// x        - forward input; only its shape, stream and place are used here.
// out      - forward output of ReLU; supplies the element count, the dispatch
//            type, and the sign mask read by the kernel.
// grad_out - gradient with respect to `out`.
// grad_x   - destination tensor; it is reshaped to x.shape() and its buffer
//            is obtained on x.place() before the kernel is launched.
void relu_cuda_backward_out(const paddle::Tensor& x,
                            const paddle::Tensor& out,
                            const paddle::Tensor& grad_out,
                            paddle::Tensor* grad_x) {
  // NOTE(review): the element count is kept as int because the kernel takes an
  // int count; if out.size() is wider than int, counts above INT_MAX would be
  // truncated here -- confirm against the Tensor API.
  int numel = out.size();
  int block = 512;
  // Ceil-divide in 64 bits so `numel + block - 1` cannot overflow int, and
  // always launch at least one block: a grid dimension of 0 is an invalid
  // launch configuration, so an empty tensor would otherwise produce a CUDA
  // launch error instead of a no-op.
  long long blocks_needed =
      (static_cast<long long>(numel) + block - 1) / block;
  int grid = blocks_needed > 0 ? static_cast<int>(blocks_needed) : 1;
  grad_x->reshape(x.shape());
  // data_t is supplied by the dispatch macro according to out.type().
  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      out.type(), "relu_cuda_backward_kernel", ([&] {
        relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
            grad_out.data<data_t>(),
            out.data<data_t>(),
            grad_x->mutable_data<data_t>(x.place()),
            numel);
      }));
}