custom_relu_op.cc 9.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

20 21 22
// Passes `x.place() == paddle::PlaceType::kCPU` to PD_CHECK together with the
// message "<x> must be a CPU Tensor.", where <x> is the stringified argument.
// The argument is parenthesized so any tensor expression can be passed safely.
#define CHECK_CPU_INPUT(x)                          \
  PD_CHECK((x).place() == paddle::PlaceType::kCPU, \
           #x " must be a CPU Tensor.")

23 24 25 26
// Element-wise ReLU over a contiguous buffer.
//
// Writes max(0, x_data[i]) into out_data[i] for every i in [0, x_numel).
// Both pointers are validated with PD_CHECK before the loop runs.
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
                             data_t* out_data,
                             int64_t x_numel) {
  PD_CHECK(x_data != nullptr, "x_data is nullptr.");
  PD_CHECK(out_data != nullptr, "out_data is nullptr.");
  const data_t zero = static_cast<data_t>(0.);
  // The index is int64_t to match x_numel; an int counter would overflow
  // before reaching element counts above INT_MAX.
  for (int64_t i = 0; i < x_numel; ++i) {
    out_data[i] = std::max(zero, x_data[i]);
  }
}

// Element-wise ReLU gradient over contiguous buffers.
//
// For every i in [0, out_numel): grad_x_data[i] is grad_out_data[i] multiplied
// by 1 when out_data[i] > 0 and by 0 otherwise. No pointer validation is done
// here; with out_numel == 0 nothing is read or written.
template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
                              const data_t* out_data,
                              data_t* grad_x_data,
                              int64_t out_numel) {
  // The index is int64_t to match out_numel; an int counter would overflow
  // before reaching element counts above INT_MAX.
  for (int64_t i = 0; i < out_numel; ++i) {
    grad_x_data[i] =
        grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
  }
}

45 46 47 48 49 50 51 52 53 54 55
// Element-wise second-order ReLU gradient over contiguous buffers.
//
// For every position in [0, ddout_numel): ddout_data receives ddx_data scaled
// by 1 when the matching out_data entry is > 0, and by 0 otherwise.
template <typename data_t>
void relu_cpu_double_backward_kernel(const data_t* out_data,
                                     const data_t* ddx_data,
                                     data_t* ddout_data,
                                     int64_t ddout_numel) {
  int64_t idx = 0;
  while (idx < ddout_numel) {
    // The gate stays a double so the product is computed exactly as before.
    const double gate = out_data[idx] > static_cast<data_t>(0) ? 1. : 0.;
    ddout_data[idx] = ddx_data[idx] * gate;
    ++idx;
  }
}

56
// CPU forward pass for custom_relu.
//
// Creates a CPU tensor with the shape of `x`, runs relu_cpu_forward_kernel
// over x.size() elements, and returns it as the single entry of the result.
// `data_t` inside the lambda is not declared in this file; it is supplied by
// PD_DISPATCH_FLOATING_TYPES, which is given x.type().
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
  // The kernel dereferences the data pointer on the host, so reject non-CPU
  // inputs in the same way relu_cpu_double_backward does.
  CHECK_CPU_INPUT(x);
  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "relu_cpu_forward", ([&] {
        relu_cpu_forward_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
      }));

  return {out};
}

// CPU backward pass for custom_relu.
//
// Creates a CPU tensor with the shape of `x`, fills it via
// relu_cpu_backward_kernel from `grad_out` and `out` over out.size() elements,
// and returns it as the single entry of the result. `data_t` inside the lambda
// is supplied by PD_DISPATCH_FLOATING_TYPES, which is given out.type().
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
                                              const paddle::Tensor& out,
                                              const paddle::Tensor& grad_out) {
  // Every input is read (or used as the allocation place) on the host, so
  // reject non-CPU tensors like relu_cpu_double_backward does.
  CHECK_CPU_INPUT(x);
  CHECK_CPU_INPUT(out);
  CHECK_CPU_INPUT(grad_out);
  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                               relu_cpu_backward_kernel<data_t>(
                                   grad_out.data<data_t>(),
                                   out.data<data_t>(),
                                   grad_x.mutable_data<data_t>(x.place()),
                                   out.size());
                             }));

  return {grad_x};
}

84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
// CPU double-backward pass for custom_relu.
//
// Checks that `out` and `ddx` are CPU tensors, creates a CPU tensor shaped
// like `out`, fills it via relu_cpu_double_backward_kernel, prints a debug
// line to stdout, and returns the tensor as the single entry of the result.
std::vector<paddle::Tensor> relu_cpu_double_backward(
    const paddle::Tensor& out, const paddle::Tensor& ddx) {
  CHECK_CPU_INPUT(out);
  CHECK_CPU_INPUT(ddx);

  paddle::Tensor ddout(paddle::PlaceType::kCPU, out.shape());

  // data_t is supplied by the dispatch macro, which is given out.type().
  PD_DISPATCH_FLOATING_TYPES(
      out.type(), "relu_cpu_double_backward", ([&] {
        const data_t* out_ptr = out.data<data_t>();
        const data_t* ddx_ptr = ddx.data<data_t>();
        data_t* ddout_ptr = ddout.mutable_data<data_t>(out.place());
        relu_cpu_double_backward_kernel<data_t>(
            out_ptr, ddx_ptr, ddout_ptr, ddout.size());
      }));

  std::cout << "Debug info: run relu cpu double backward success." << std::endl;

  return {ddout};
}

103 104 105 106
// Forward declarations of the GPU forward/backward implementations called by
// ReluForward and ReluBackward. Their definitions are not in this file
// (presumably in the accompanying CUDA source -- confirm against the build).
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out);
107 108
// Forward declaration of the GPU double-backward implementation called by
// ReluDoubleBackward. Its definition is not in this file (presumably in the
// accompanying CUDA source -- confirm against the build).
std::vector<paddle::Tensor> relu_cuda_double_backward(
    const paddle::Tensor& out, const paddle::Tensor& ddx);
109 110 111 112 113 114 115 116

// Forward kernel entry point registered for custom_relu and
// custom_relu_no_x_in_backward.
//
// Dispatches on x.place(): kCPU -> relu_cpu_forward, kGPU -> relu_cuda_forward.
// Any other place reaches PD_THROW("Not implemented.").
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
  // TODO(chenweihang): Check Input
  if (x.place() == paddle::PlaceType::kCPU) {
    return relu_cpu_forward(x);
  } else if (x.place() == paddle::PlaceType::kGPU) {
    return relu_cuda_forward(x);
  } else {
    PD_THROW("Not implemented.");
  }
}

// Backward kernel entry point registered for the custom_relu grad op.
//
// Dispatches on x.place(): kCPU -> relu_cpu_backward, kGPU ->
// relu_cuda_backward. Any other place reaches PD_THROW("Not implemented.").
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
                                         const paddle::Tensor& out,
                                         const paddle::Tensor& grad_out) {
  // TODO(chenweihang): Check Input
  if (x.place() == paddle::PlaceType::kCPU) {
    return relu_cpu_backward(x, out, grad_out);
  } else if (x.place() == paddle::PlaceType::kGPU) {
    return relu_cuda_backward(x, out, grad_out);
  } else {
    PD_THROW("Not implemented.");
  }
}

134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
// Double-backward kernel entry point registered for custom_relu.
//
// Chooses the implementation from out.place(): the CPU version for kCPU, the
// CUDA version for kGPU; anything else ends in PD_THROW("Not implemented.").
std::vector<paddle::Tensor> ReluDoubleBackward(const paddle::Tensor& out,
                                               const paddle::Tensor& ddx) {
  const bool on_cpu = out.place() == paddle::PlaceType::kCPU;
  if (on_cpu) {
    return relu_cpu_double_backward(out, ddx);
  }

  const bool on_gpu = out.place() == paddle::PlaceType::kGPU;
  if (on_gpu) {
    return relu_cuda_double_backward(out, ddx);
  }

  PD_THROW("Not implemented.");
}

// Shape inference for the custom_relu double-grad op.
//
// Returns a one-element list holding a copy of `out_shape`; `ddx_shape` is
// accepted but not consulted.
std::vector<std::vector<int64_t>> ReluDoubleBackwardInferShape(
    const std::vector<int64_t>& out_shape,
    const std::vector<int64_t>& ddx_shape) {
  std::vector<std::vector<int64_t>> inferred_shapes;
  inferred_shapes.push_back(out_shape);
  return inferred_shapes;
}

151
// Registers the forward op `custom_relu`: one input "X", one output "Out",
// computed by ReluForward.
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));

// Registers the grad op of `custom_relu`: inputs "X", "Out" and Grad("Out"),
// output Grad("X"), computed by ReluBackward.
PD_BUILD_GRAD_OP(custom_relu)
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));
160

161 162 163 164 165 166
// Registers the double-grad op of `custom_relu`: inputs "Out" and
// Grad(Grad("X")), output Grad(Grad("Out")), computed by ReluDoubleBackward
// with output shapes given by ReluDoubleBackwardInferShape.
PD_BUILD_DOUBLE_GRAD_OP(custom_relu)
    .Inputs({"Out", paddle::Grad(paddle::Grad("X"))})
    .Outputs({paddle::Grad(paddle::Grad("Out"))})
    .SetKernelFn(PD_KERNEL(ReluDoubleBackward))
    .SetInferShapeFn(PD_INFER_SHAPE(ReluDoubleBackwardInferShape));

167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
// CPU backward pass for custom_relu_no_x_in_backward, which does not receive
// the forward input.
//
// Creates a CPU tensor with the shape of `out`, fills it via
// relu_cpu_backward_kernel from `grad_out` and `out` over out.size() elements,
// and returns it as the single entry of the result. `data_t` inside the lambda
// is supplied by PD_DISPATCH_FLOATING_TYPES, which is given out.type().
std::vector<paddle::Tensor> relu_cpu_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  // Both inputs are read on the host, so reject non-CPU tensors like
  // relu_cpu_double_backward does.
  CHECK_CPU_INPUT(out);
  CHECK_CPU_INPUT(grad_out);
  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());

  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                               relu_cpu_backward_kernel<data_t>(
                                   grad_out.data<data_t>(),
                                   out.data<data_t>(),
                                   grad_x.mutable_data<data_t>(out.place()),
                                   out.size());
                             }));

  return {grad_x};
}

// Forward declaration of the GPU backward implementation called by
// ReluBackwardWithoutX. Its definition is not in this file (presumably in the
// accompanying CUDA source -- confirm against the build).
std::vector<paddle::Tensor> relu_cuda_backward_without_x(
    const paddle::Tensor& out, const paddle::Tensor& grad_out);

// Backward kernel entry point registered for custom_relu_no_x_in_backward.
//
// Chooses the implementation from out.place(): the CPU version for kCPU, the
// CUDA version for kGPU; anything else ends in PD_THROW("Not implemented.").
std::vector<paddle::Tensor> ReluBackwardWithoutX(
    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
  const bool on_cpu = out.place() == paddle::PlaceType::kCPU;
  if (on_cpu) {
    return relu_cpu_backward_without_x(out, grad_out);
  }

  const bool on_gpu = out.place() == paddle::PlaceType::kGPU;
  if (on_gpu) {
    return relu_cuda_backward_without_x(out, grad_out);
  }

  PD_THROW("Not implemented.");
}

// Shape inference for the custom_relu_no_x_in_backward grad op.
//
// Returns a one-element list holding a copy of `out_shape`; `grad_out_shape`
// is accepted but not consulted.
std::vector<std::vector<int64_t>> ReluBackwardWithoutXInferShape(
    const std::vector<int64_t>& out_shape,
    const std::vector<int64_t>& grad_out_shape) {
  std::vector<std::vector<int64_t>> inferred_shapes;
  inferred_shapes.push_back(out_shape);
  return inferred_shapes;
}

// Registers the forward op `custom_relu_no_x_in_backward`: one input "X", one
// output "Out", computed by the same ReluForward kernel as `custom_relu`.
PD_BUILD_OP(custom_relu_no_x_in_backward)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));

// Registers the grad op of `custom_relu_no_x_in_backward`: inputs "Out" and
// Grad("Out") only ("X" is not listed), output Grad("X"), computed by
// ReluBackwardWithoutX with shapes given by ReluBackwardWithoutXInferShape.
PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
    .Inputs({"Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
    .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
212 213

// CPU forward pass for custom_relu_out, writing into a caller-provided tensor.
//
// Reshapes `*out` to the shape of `x`, then runs relu_cpu_forward_kernel over
// x.size() elements. `data_t` inside the lambda is supplied by
// PD_DISPATCH_FLOATING_TYPES, which is given x.type().
void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
  // The kernel dereferences the data pointer on the host, so reject non-CPU
  // inputs in the same way relu_cpu_double_backward does.
  CHECK_CPU_INPUT(x);
  out->reshape(x.shape());
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "relu_cpu_forward", ([&] {
        relu_cpu_forward_kernel<data_t>(
            x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.size());
      }));
}

// CPU backward pass for custom_relu_out, writing into a caller-provided
// tensor.
//
// Reshapes `*grad_x` to the shape of `x`, then fills it via
// relu_cpu_backward_kernel from `grad_out` and `out` over out.size() elements.
// `data_t` inside the lambda is supplied by PD_DISPATCH_FLOATING_TYPES, which
// is given out.type().
void relu_cpu_backward_out(const paddle::Tensor& x,
                           const paddle::Tensor& out,
                           const paddle::Tensor& grad_out,
                           paddle::Tensor* grad_x) {
  // Every input is read (or used as the allocation place) on the host, so
  // reject non-CPU tensors like relu_cpu_double_backward does.
  CHECK_CPU_INPUT(x);
  CHECK_CPU_INPUT(out);
  CHECK_CPU_INPUT(grad_out);
  grad_x->reshape(x.shape());
  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                               relu_cpu_backward_kernel<data_t>(
                                   grad_out.data<data_t>(),
                                   out.data<data_t>(),
                                   grad_x->mutable_data<data_t>(x.place()),
                                   out.size());
                             }));
}

// Forward declarations of the GPU implementations called by ReluForwardOut and
// ReluBackwardOut; each writes its result through the trailing pointer
// parameter. Their definitions are not in this file (presumably in the
// accompanying CUDA source -- confirm against the build).
void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out);
void relu_cuda_backward_out(const paddle::Tensor& x,
                            const paddle::Tensor& out,
                            const paddle::Tensor& grad_out,
                            paddle::Tensor* grad_x);

// Forward kernel entry point registered for custom_relu_out.
//
// Chooses the implementation from x.place(): the CPU version for kCPU, the
// CUDA version for kGPU; anything else ends in PD_THROW("Not implemented.").
// The result is written through `out`.
void ReluForwardOut(const paddle::Tensor& x, paddle::Tensor* out) {
  if (x.place() == paddle::PlaceType::kCPU) {
    relu_cpu_forward_out(x, out);
    return;
  }

  if (x.place() == paddle::PlaceType::kGPU) {
    relu_cuda_forward_out(x, out);
    return;
  }

  PD_THROW("Not implemented.");
}

// Backward kernel entry point registered for the custom_relu_out grad op.
//
// Chooses the implementation from x.place(): the CPU version for kCPU, the
// CUDA version for kGPU; anything else ends in PD_THROW("Not implemented.").
// The result is written through `grad_x`.
void ReluBackwardOut(const paddle::Tensor& x,
                     const paddle::Tensor& out,
                     const paddle::Tensor& grad_out,
                     paddle::Tensor* grad_x) {
  if (x.place() == paddle::PlaceType::kCPU) {
    relu_cpu_backward_out(x, out, grad_out, grad_x);
    return;
  }

  if (x.place() == paddle::PlaceType::kGPU) {
    relu_cuda_backward_out(x, out, grad_out, grad_x);
    return;
  }

  PD_THROW("Not implemented.");
}

// Registers the forward op `custom_relu_out`: one input "X", one output "Out",
// computed by ReluForwardOut, which writes into an output tensor pointer
// rather than returning a vector of tensors.
PD_BUILD_OP(custom_relu_out)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForwardOut));

// Registers the grad op of `custom_relu_out`: inputs "X", "Out" and
// Grad("Out"), output Grad("X"), computed by ReluBackwardOut.
PD_BUILD_GRAD_OP(custom_relu_out)
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackwardOut));