Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2a17e3c1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2a17e3c1
编写于
6月 05, 2022
作者:
C
Chen Weihang
提交者:
GitHub
6月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update relu custom op demo (#43173)
上级
19b4ff47
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
32 addition
and
35 deletion
+32
-35
python/paddle/fluid/tests/custom_op/custom_relu_op.cc
python/paddle/fluid/tests/custom_op/custom_relu_op.cc
+12
-14
python/paddle/fluid/tests/custom_op/custom_relu_op.cu
python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+20
-21
未找到文件。
python/paddle/fluid/tests/custom_op/custom_relu_op.cc
浏览文件 @
2a17e3c1
...
...
@@ -17,8 +17,7 @@
#include "paddle/extension.h"
#define CHECK_CPU_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
template
<
typename
data_t
>
void
relu_cpu_forward_kernel
(
const
data_t
*
x_data
,
...
...
@@ -26,7 +25,7 @@ void relu_cpu_forward_kernel(const data_t* x_data,
int64_t
x_numel
)
{
PD_CHECK
(
x_data
!=
nullptr
,
"x_data is nullptr."
);
PD_CHECK
(
out_data
!=
nullptr
,
"out_data is nullptr."
);
for
(
int
i
=
0
;
i
<
x_numel
;
++
i
)
{
for
(
int
64_t
i
=
0
;
i
<
x_numel
;
++
i
)
{
out_data
[
i
]
=
std
::
max
(
static_cast
<
data_t
>
(
0.
),
x_data
[
i
]);
}
}
...
...
@@ -36,7 +35,7 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
const
data_t
*
out_data
,
data_t
*
grad_x_data
,
int64_t
out_numel
)
{
for
(
int
i
=
0
;
i
<
out_numel
;
++
i
)
{
for
(
int
64_t
i
=
0
;
i
<
out_numel
;
++
i
)
{
grad_x_data
[
i
]
=
grad_out_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0
)
?
1.
:
0.
);
}
...
...
@@ -54,12 +53,12 @@ void relu_cpu_double_backward_kernel(const data_t* out_data,
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_forward
(
const
paddle
::
Tensor
&
x
)
{
auto
out
=
paddle
::
empty
(
x
.
shape
(),
x
.
dtype
(),
x
.
place
()
);
auto
out
=
paddle
::
empty
_like
(
x
);
PD_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"relu_cpu_forward"
,
([
&
]
{
relu_cpu_forward_kernel
<
data_t
>
(
x
.
data
<
data_t
>
(),
out
.
mutable_data
<
data_t
>
(
x
.
place
()),
x
.
size
());
x
.
data
<
data_t
>
(),
out
.
data
<
data_t
>
(),
x
.
numel
());
}));
return
{
out
};
...
...
@@ -68,13 +67,13 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_backward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
auto
grad_x
=
paddle
::
empty
(
x
.
shape
(),
x
.
dtype
(),
x
.
place
()
);
auto
grad_x
=
paddle
::
empty
_like
(
x
);
PD_DISPATCH_FLOATING_TYPES
(
out
.
type
(),
"relu_cpu_backward"
,
([
&
]
{
relu_cpu_backward_kernel
<
data_t
>
(
grad_out
.
data
<
data_t
>
(),
out
.
data
<
data_t
>
(),
grad_x
.
mutable_data
<
data_t
>
(
x
.
place
()
),
grad_x
.
data
<
data_t
>
(
),
out
.
size
());
}));
...
...
@@ -108,9 +107,9 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
);
std
::
vector
<
paddle
::
Tensor
>
ReluForward
(
const
paddle
::
Tensor
&
x
)
{
if
(
x
.
place
()
==
paddle
::
PlaceType
::
kCPU
)
{
if
(
x
.
is_cpu
()
)
{
return
relu_cpu_forward
(
x
);
}
else
if
(
x
.
place
()
==
paddle
::
PlaceType
::
kGPU
)
{
}
else
if
(
x
.
is_gpu
()
)
{
return
relu_cuda_forward
(
x
);
}
else
{
PD_THROW
(
"Not implemented."
);
...
...
@@ -120,10 +119,9 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
std
::
vector
<
paddle
::
Tensor
>
ReluBackward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
// TODO(chenweihang): Check Input
if
(
x
.
place
()
==
paddle
::
PlaceType
::
kCPU
)
{
if
(
x
.
is_cpu
())
{
return
relu_cpu_backward
(
x
,
out
,
grad_out
);
}
else
if
(
x
.
place
()
==
paddle
::
PlaceType
::
kGPU
)
{
}
else
if
(
x
.
is_gpu
()
)
{
return
relu_cuda_backward
(
x
,
out
,
grad_out
);
}
else
{
PD_THROW
(
"Not implemented."
);
...
...
@@ -214,7 +212,7 @@ void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
PD_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"relu_cpu_forward"
,
([
&
]
{
relu_cpu_forward_kernel
<
data_t
>
(
x
.
data
<
data_t
>
(),
out
->
mutable_data
<
data_t
>
(
x
.
place
()),
x
.
size
());
x
.
data
<
data_t
>
(),
out
->
mutable_data
<
data_t
>
(
x
.
place
()),
x
.
numel
());
}));
}
...
...
python/paddle/fluid/tests/custom_op/custom_relu_op.cu
浏览文件 @
2a17e3c1
...
...
@@ -14,15 +14,14 @@
#include "paddle/extension.h"
#define CHECK_GPU_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
template
<
typename
data_t
>
__global__
void
relu_cuda_forward_kernel
(
const
data_t
*
x
,
data_t
*
y
,
const
in
t
num
)
{
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
i
=
gid
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int64_
t
num
)
{
int
64_t
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
64_t
i
=
gid
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
y
[
i
]
=
x
[
i
]
>
static_cast
<
data_t
>
(
0.
)
?
x
[
i
]
:
static_cast
<
data_t
>
(
0.
);
}
}
...
...
@@ -31,9 +30,9 @@ template <typename data_t>
__global__
void
relu_cuda_backward_kernel
(
const
data_t
*
dy
,
const
data_t
*
y
,
data_t
*
dx
,
const
in
t
num
)
{
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
i
=
gid
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int64_
t
num
)
{
int
64_t
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
64_t
i
=
gid
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
dx
[
i
]
=
dy
[
i
]
*
(
y
[
i
]
>
static_cast
<
data_t
>
(
0.
)
?
static_cast
<
data_t
>
(
1.
)
:
static_cast
<
data_t
>
(
0.
));
}
...
...
@@ -54,15 +53,15 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_forward
(
const
paddle
::
Tensor
&
x
)
{
CHECK_GPU_INPUT
(
x
);
auto
out
=
paddle
::
empty
(
x
.
shape
(),
x
.
dtype
(),
x
.
place
()
);
auto
out
=
paddle
::
empty
_like
(
x
);
int
numel
=
x
.
size
();
int
block
=
512
;
int
grid
=
(
numel
+
block
-
1
)
/
block
;
int
64_t
numel
=
x
.
numel
();
int
64_t
block
=
512
;
int
64_t
grid
=
(
numel
+
block
-
1
)
/
block
;
PD_DISPATCH_FLOATING_AND_HALF_TYPES
(
x
.
type
(),
"relu_cuda_forward_kernel"
,
([
&
]
{
relu_cuda_forward_kernel
<
data_t
><<<
grid
,
block
,
0
,
x
.
stream
()
>>>
(
x
.
data
<
data_t
>
(),
out
.
mutable_data
<
data_t
>
(
x
.
place
()
),
numel
);
x
.
data
<
data_t
>
(),
out
.
data
<
data_t
>
(
),
numel
);
}));
return
{
out
};
...
...
@@ -74,11 +73,11 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
CHECK_GPU_INPUT
(
x
);
CHECK_GPU_INPUT
(
out
);
CHECK_GPU_INPUT
(
grad_out
);
auto
grad_x
=
paddle
::
empty
(
x
.
shape
(),
x
.
dtype
(),
x
.
place
()
);
auto
grad_x
=
paddle
::
empty
_like
(
x
);
int
numel
=
out
.
size
();
int
block
=
512
;
int
grid
=
(
numel
+
block
-
1
)
/
block
;
int
64_t
numel
=
out
.
numel
();
int
64_t
block
=
512
;
int
64_t
grid
=
(
numel
+
block
-
1
)
/
block
;
PD_DISPATCH_FLOATING_AND_HALF_TYPES
(
out
.
type
(),
"relu_cuda_backward_kernel"
,
([
&
]
{
relu_cuda_backward_kernel
<
data_t
><<<
grid
,
block
,
0
,
x
.
stream
()
>>>
(
...
...
@@ -97,7 +96,7 @@ std::vector<paddle::Tensor> relu_cuda_double_backward(
CHECK_GPU_INPUT
(
ddx
);
auto
ddout
=
paddle
::
empty
(
out
.
shape
(),
out
.
dtype
(),
out
.
place
());
int64_t
numel
=
out
.
size
();
int64_t
numel
=
out
.
numel
();
int64_t
block
=
512
;
int64_t
grid
=
(
numel
+
block
-
1
)
/
block
;
PD_DISPATCH_FLOATING_AND_HALF_TYPES
(
...
...
@@ -119,7 +118,7 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
auto
grad_x
=
paddle
::
empty
(
out
.
shape
(),
out
.
dtype
(),
out
.
place
());
int
numel
=
out
.
size
();
int
numel
=
out
.
numel
();
int
block
=
512
;
int
grid
=
(
numel
+
block
-
1
)
/
block
;
PD_DISPATCH_FLOATING_AND_HALF_TYPES
(
...
...
@@ -135,7 +134,7 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(
}
void
relu_cuda_forward_out
(
const
paddle
::
Tensor
&
x
,
paddle
::
Tensor
*
out
)
{
int
numel
=
x
.
size
();
int
numel
=
x
.
numel
();
int
block
=
512
;
int
grid
=
(
numel
+
block
-
1
)
/
block
;
out
->
reshape
(
x
.
shape
());
...
...
@@ -150,7 +149,7 @@ void relu_cuda_backward_out(const paddle::Tensor& x,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
,
paddle
::
Tensor
*
grad_x
)
{
int
numel
=
out
.
size
();
int
numel
=
out
.
numel
();
int
block
=
512
;
int
grid
=
(
numel
+
block
-
1
)
/
block
;
grad_x
->
reshape
(
x
.
shape
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录