Unverified commit 2a17e3c1, authored by Chen Weihang, committed by GitHub

update relu custom op demo (#43173)

Parent: 19b4ff47
@@ -17,8 +17,7 @@
 #include "paddle/extension.h"
-#define CHECK_CPU_INPUT(x) \
-  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
+#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
 template <typename data_t>
 void relu_cpu_forward_kernel(const data_t* x_data,
@@ -26,7 +25,7 @@ void relu_cpu_forward_kernel(const data_t* x_data,
                              int64_t x_numel) {
   PD_CHECK(x_data != nullptr, "x_data is nullptr.");
   PD_CHECK(out_data != nullptr, "out_data is nullptr.");
-  for (int i = 0; i < x_numel; ++i) {
+  for (int64_t i = 0; i < x_numel; ++i) {
     out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
   }
 }
@@ -36,7 +35,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
                               const data_t* out_data,
                               data_t* grad_x_data,
                               int64_t out_numel) {
-  for (int i = 0; i < out_numel; ++i) {
+  for (int64_t i = 0; i < out_numel; ++i) {
     grad_x_data[i] =
         grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
   }
@@ -54,12 +53,12 @@ void relu_cpu_double_backward_kernel(const data_t* out_data,
 }
 std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
-  auto out = paddle::empty(x.shape(), x.dtype(), x.place());
+  auto out = paddle::empty_like(x);
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
         relu_cpu_forward_kernel<data_t>(
-            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
+            x.data<data_t>(), out.data<data_t>(), x.numel());
       }));
   return {out};
@@ -68,13 +67,13 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
 std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out) {
-  auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place());
+  auto grad_x = paddle::empty_like(x);
   PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                                relu_cpu_backward_kernel<data_t>(
                                    grad_out.data<data_t>(),
                                    out.data<data_t>(),
-                                   grad_x.mutable_data<data_t>(x.place()),
+                                   grad_x.data<data_t>(),
                                    out.size());
                              }));
@@ -108,9 +107,9 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
     const paddle::Tensor& out, const paddle::Tensor& ddx);
 std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
-  if (x.place() == paddle::PlaceType::kCPU) {
+  if (x.is_cpu()) {
     return relu_cpu_forward(x);
-  } else if (x.place() == paddle::PlaceType::kGPU) {
+  } else if (x.is_gpu()) {
     return relu_cuda_forward(x);
   } else {
     PD_THROW("Not implemented.");
@@ -120,10 +119,9 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
 std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
                                          const paddle::Tensor& out,
                                          const paddle::Tensor& grad_out) {
-  // TODO(chenweihang): Check Input
-  if (x.place() == paddle::PlaceType::kCPU) {
+  if (x.is_cpu()) {
     return relu_cpu_backward(x, out, grad_out);
-  } else if (x.place() == paddle::PlaceType::kGPU) {
+  } else if (x.is_gpu()) {
     return relu_cuda_backward(x, out, grad_out);
   } else {
     PD_THROW("Not implemented.");
@@ -214,7 +212,7 @@ void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
   PD_DISPATCH_FLOATING_TYPES(
       x.type(), "relu_cpu_forward", ([&] {
         relu_cpu_forward_kernel<data_t>(
-            x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.size());
+            x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.numel());
       }));
 }
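The hunks above migrate the CPU half of the demo from the old PlaceType / mutable_data / size() calls to the current Tensor API: x.is_cpu(), paddle::empty_like, Tensor::data<T>(), and Tensor::numel(). Because the change is spread over several hunks, a consolidated sketch of the CPU forward path as it reads after this commit follows. It restates what the diff already shows; the out_data parameter and the CHECK_CPU_INPUT(x) call are plausible completions of context the diff elides, not verbatim lines from this commit.

// Consolidated "after" view of the CPU forward path (sketch only).
#include <algorithm>
#include <vector>

#include "paddle/extension.h"

#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")

template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
                             data_t* out_data,  // inferred from the kernel body
                             int64_t x_numel) {
  PD_CHECK(x_data != nullptr, "x_data is nullptr.");
  PD_CHECK(out_data != nullptr, "out_data is nullptr.");
  // 64-bit loop index, matching the int64_t element count.
  for (int64_t i = 0; i < x_numel; ++i) {
    out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
  }
}

std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
  CHECK_CPU_INPUT(x);
  // empty_like allocates out with the same shape, dtype, and place as x,
  // so data<data_t>() can be used directly instead of mutable_data.
  auto out = paddle::empty_like(x);
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "relu_cpu_forward", ([&] {
        relu_cpu_forward_kernel<data_t>(
            x.data<data_t>(), out.data<data_t>(), x.numel());
      }));
  return {out};
}

The hunks below make the same migration in the CUDA source: the grid-stride kernels switch to int64_t indices, and the launch sites compute numel, block, and grid as int64_t and write through data<data_t>() on a tensor allocated with paddle::empty_like.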
@@ -14,15 +14,14 @@
 #include "paddle/extension.h"
-#define CHECK_GPU_INPUT(x) \
-  PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
+#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
 template <typename data_t>
 __global__ void relu_cuda_forward_kernel(const data_t* x,
                                          data_t* y,
-                                         const int num) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+                                         int64_t num) {
+  int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     y[i] = x[i] > static_cast<data_t>(0.) ? x[i] : static_cast<data_t>(0.);
   }
 }
@@ -31,9 +30,9 @@ template <typename data_t>
 __global__ void relu_cuda_backward_kernel(const data_t* dy,
                                           const data_t* y,
                                           data_t* dx,
-                                          const int num) {
-  int gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+                                          int64_t num) {
+  int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     dx[i] = dy[i] * (y[i] > static_cast<data_t>(0.) ? static_cast<data_t>(1.)
                                                     : static_cast<data_t>(0.));
   }
@@ -54,15 +53,15 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
 std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
   CHECK_GPU_INPUT(x);
-  auto out = paddle::empty(x.shape(), x.dtype(), x.place());
+  auto out = paddle::empty_like(x);
-  int numel = x.size();
-  int block = 512;
-  int grid = (numel + block - 1) / block;
+  int64_t numel = x.numel();
+  int64_t block = 512;
+  int64_t grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       x.type(), "relu_cuda_forward_kernel", ([&] {
         relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
-            x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
+            x.data<data_t>(), out.data<data_t>(), numel);
       }));
   return {out};
@@ -74,11 +73,11 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
   CHECK_GPU_INPUT(x);
   CHECK_GPU_INPUT(out);
   CHECK_GPU_INPUT(grad_out);
-  auto grad_x = paddle::empty(x.shape(), x.dtype(), x.place());
+  auto grad_x = paddle::empty_like(x);
-  int numel = out.size();
-  int block = 512;
-  int grid = (numel + block - 1) / block;
+  int64_t numel = out.numel();
+  int64_t block = 512;
+  int64_t grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
       out.type(), "relu_cuda_backward_kernel", ([&] {
         relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
@@ -97,7 +96,7 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
   CHECK_GPU_INPUT(ddx);
   auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
-  int64_t numel = out.size();
+  int64_t numel = out.numel();
   int64_t block = 512;
   int64_t grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
@@ -119,7 +118,7 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(
     const paddle::Tensor& out, const paddle::Tensor& grad_out) {
   auto grad_x = paddle::empty(out.shape(), out.dtype(), out.place());
-  int numel = out.size();
+  int numel = out.numel();
   int block = 512;
   int grid = (numel + block - 1) / block;
   PD_DISPATCH_FLOATING_AND_HALF_TYPES(
@@ -135,7 +134,7 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(
 }
 void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
-  int numel = x.size();
+  int numel = x.numel();
   int block = 512;
   int grid = (numel + block - 1) / block;
   out->reshape(x.shape());
@@ -150,7 +149,7 @@ void relu_cuda_backward_out(const paddle::Tensor& x,
                             const paddle::Tensor& out,
                             const paddle::Tensor& grad_out,
                             paddle::Tensor* grad_x) {
-  int numel = out.size();
+  int numel = out.numel();
   int block = 512;
   int grid = (numel + block - 1) / block;
   grad_x->reshape(x.shape());
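Neither file in this excerpt shows how ReluForward and ReluBackward are exposed as an operator; that registration lives outside the changed hunks. For orientation only, a hedged sketch of the usual custom-op registration is below. The op name custom_relu and the tensor names X and Out are assumptions for illustration, not taken from this commit.

// Sketch, not verbatim from this commit: typical registration of the
// dispatch functions defined in the C++ source above. The op name
// "custom_relu" and the tensor names "X"/"Out" are assumed for illustration.
#include "paddle/extension.h"

PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));

PD_BUILD_GRAD_OP(custom_relu)
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));

Once compiled and loaded (for example via paddle.utils.cpp_extension), the registered op dispatches to the CPU or GPU path above through the is_cpu()/is_gpu() checks introduced in this commit.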