Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0ee76f92
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0ee76f92
编写于
3月 26, 2022
作者:
C
Chen Weihang
提交者:
GitHub
3月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add double grad op example (#40963)
上级
b94cf842
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
137 addition
and
1 deletion
+137
-1
paddle/phi/api/lib/op_meta_info.cc
paddle/phi/api/lib/op_meta_info.cc
+1
-0
python/paddle/fluid/tests/custom_op/custom_relu_op.cc
python/paddle/fluid/tests/custom_op/custom_relu_op.cc
+58
-0
python/paddle/fluid/tests/custom_op/custom_relu_op.cu
python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+40
-0
python/paddle/fluid/tests/custom_op/custom_relu_setup.py
python/paddle/fluid/tests/custom_op/custom_relu_setup.py
+2
-1
python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
...paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
+36
-0
未找到文件。
paddle/phi/api/lib/op_meta_info.cc
浏览文件 @
0ee76f92
...
...
@@ -192,6 +192,7 @@ OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name, size_t index) {
break
;
case
2
:
name_
=
name_
+
"_grad_grad"
;
break
;
default:
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"Not support index `%d` when construct OpMetaInfoBuilder, "
...
...
python/paddle/fluid/tests/custom_op/custom_relu_op.cc
浏览文件 @
0ee76f92
...
...
@@ -17,6 +17,9 @@
#include "paddle/extension.h"
#define CHECK_CPU_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
template
<
typename
data_t
>
void
relu_cpu_forward_kernel
(
const
data_t
*
x_data
,
data_t
*
out_data
,
...
...
@@ -39,6 +42,17 @@ void relu_cpu_backward_kernel(const data_t* grad_out_data,
}
}
template
<
typename
data_t
>
void
relu_cpu_double_backward_kernel
(
const
data_t
*
out_data
,
const
data_t
*
ddx_data
,
data_t
*
ddout_data
,
int64_t
ddout_numel
)
{
for
(
int64_t
i
=
0
;
i
<
ddout_numel
;
++
i
)
{
ddout_data
[
i
]
=
ddx_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0
)
?
1.
:
0.
);
}
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_forward
(
const
paddle
::
Tensor
&
x
)
{
auto
out
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
,
x
.
shape
());
...
...
@@ -67,10 +81,31 @@ std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
return
{
grad_x
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_double_backward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
CHECK_CPU_INPUT
(
out
);
CHECK_CPU_INPUT
(
ddx
);
auto
ddout
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
,
out
.
shape
());
PD_DISPATCH_FLOATING_TYPES
(
out
.
type
(),
"relu_cpu_double_backward"
,
([
&
]
{
relu_cpu_double_backward_kernel
<
data_t
>
(
out
.
data
<
data_t
>
(),
ddx
.
data
<
data_t
>
(),
ddout
.
mutable_data
<
data_t
>
(
out
.
place
()),
ddout
.
size
());
}));
std
::
cout
<<
"Debug info: run relu cpu double backward success."
<<
std
::
endl
;
return
{
ddout
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_forward
(
const
paddle
::
Tensor
&
x
);
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_backward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
);
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_double_backward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
);
std
::
vector
<
paddle
::
Tensor
>
ReluForward
(
const
paddle
::
Tensor
&
x
)
{
// TODO(chenweihang): Check Input
...
...
@@ -96,6 +131,23 @@ std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
}
}
std
::
vector
<
paddle
::
Tensor
>
ReluDoubleBackward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
if
(
out
.
place
()
==
paddle
::
PlaceType
::
kCPU
)
{
return
relu_cpu_double_backward
(
out
,
ddx
);
}
else
if
(
out
.
place
()
==
paddle
::
PlaceType
::
kGPU
)
{
return
relu_cuda_double_backward
(
out
,
ddx
);
}
else
{
PD_THROW
(
"Not implemented."
);
}
}
std
::
vector
<
std
::
vector
<
int64_t
>>
ReluDoubleBackwardInferShape
(
const
std
::
vector
<
int64_t
>&
out_shape
,
const
std
::
vector
<
int64_t
>&
ddx_shape
)
{
return
{
out_shape
};
}
PD_BUILD_OP
(
custom_relu
)
.
Inputs
({
"X"
})
.
Outputs
({
"Out"
})
...
...
@@ -106,6 +158,12 @@ PD_BUILD_GRAD_OP(custom_relu)
.
Outputs
({
paddle
::
Grad
(
"X"
)})
.
SetKernelFn
(
PD_KERNEL
(
ReluBackward
));
PD_BUILD_DOUBLE_GRAD_OP
(
custom_relu
)
.
Inputs
({
"Out"
,
paddle
::
Grad
(
paddle
::
Grad
(
"X"
))})
.
Outputs
({
paddle
::
Grad
(
paddle
::
Grad
(
"Out"
))})
.
SetKernelFn
(
PD_KERNEL
(
ReluDoubleBackward
))
.
SetInferShapeFn
(
PD_INFER_SHAPE
(
ReluDoubleBackwardInferShape
));
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_backward_without_x
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
auto
grad_x
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
,
out
.
shape
());
...
...
python/paddle/fluid/tests/custom_op/custom_relu_op.cu
浏览文件 @
0ee76f92
...
...
@@ -14,6 +14,9 @@
#include "paddle/extension.h"
#define CHECK_GPU_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
template
<
typename
data_t
>
__global__
void
relu_cuda_forward_kernel
(
const
data_t
*
x
,
data_t
*
y
,
...
...
@@ -36,6 +39,19 @@ __global__ void relu_cuda_backward_kernel(const data_t* dy,
}
}
template
<
typename
data_t
>
__global__
void
relu_cuda_double_backward_kernel
(
const
data_t
*
out_data
,
const
data_t
*
ddx_data
,
data_t
*
ddout_data
,
int64_t
num
)
{
int64_t
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int64_t
i
=
num
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
ddout_data
[
i
]
=
ddx_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0.
)
?
static_cast
<
data_t
>
(
1.
)
:
static_cast
<
data_t
>
(
0.
));
}
}
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_forward
(
const
paddle
::
Tensor
&
x
)
{
auto
out
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kGPU
,
x
.
shape
());
...
...
@@ -71,6 +87,30 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
return
{
grad_x
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_double_backward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
CHECK_GPU_INPUT
(
out
);
CHECK_GPU_INPUT
(
ddx
);
auto
ddout
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kGPU
,
out
.
shape
());
int64_t
numel
=
out
.
size
();
int64_t
block
=
512
;
int64_t
grid
=
(
numel
+
block
-
1
)
/
block
;
PD_DISPATCH_FLOATING_AND_HALF_TYPES
(
out
.
type
(),
"relu_cuda_double_backward_kernel"
,
([
&
]
{
relu_cuda_double_backward_kernel
<
data_t
><<<
grid
,
block
,
0
,
out
.
stream
()
>>>
(
out
.
data
<
data_t
>
(),
ddx
.
data
<
data_t
>
(),
ddout
.
mutable_data
<
data_t
>
(
out
.
place
()),
numel
);
}));
std
::
cout
<<
"Debug info: run relu gpu double backward success."
<<
std
::
endl
;
return
{
ddout
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cuda_backward_without_x
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
auto
grad_x
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kGPU
,
out
.
shape
());
...
...
python/paddle/fluid/tests/custom_op/custom_relu_setup.py
浏览文件 @
0ee76f92
...
...
@@ -31,4 +31,5 @@ setup(
ext_modules
=
Extension
(
# test for not specific name here.
sources
=
sources
,
# test for multi ops
include_dirs
=
paddle_includes
,
extra_compile_args
=
extra_compile_args
))
extra_compile_args
=
extra_compile_args
,
verbose
=
True
))
python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
浏览文件 @
0ee76f92
...
...
@@ -138,6 +138,23 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix):
return
predict_v
def
custom_relu_double_grad_dynamic
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
):
paddle
.
set_device
(
device
)
t
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
,
stop_gradient
=
False
)
out
=
func
(
t
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
t
)
out
.
stop_gradient
=
False
dx
=
paddle
.
grad
(
outputs
=
[
out
],
inputs
=
[
t
],
create_graph
=
True
,
retain_graph
=
True
)
dx
[
0
].
backward
()
assert
dx
[
0
].
grad
is
not
None
return
dx
[
0
].
numpy
(),
dx
[
0
].
grad
.
numpy
()
class
TestNewCustomOpSetUpInstall
(
unittest
.
TestCase
):
def
setUp
(
self
):
cur_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
...
...
@@ -293,6 +310,25 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
predict
,
predict_infer
))
paddle
.
disable_static
()
def
test_func_double_grad_dynamic
(
self
):
for
device
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
if
device
==
'cpu'
and
dtype
==
'float16'
:
continue
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
,
dx_grad
=
custom_relu_double_grad_dynamic
(
self
.
custom_ops
[
0
],
device
,
dtype
,
x
)
pd_out
,
pd_dx_grad
=
custom_relu_double_grad_dynamic
(
self
.
custom_ops
[
0
],
device
,
dtype
,
x
,
False
)
self
.
assertTrue
(
np
.
array_equal
(
out
,
pd_out
),
"custom op out: {},
\n
paddle api out: {}"
.
format
(
out
,
pd_out
))
self
.
assertTrue
(
np
.
array_equal
(
dx_grad
,
pd_dx_grad
),
"custom op dx grad: {},
\n
paddle api dx grad: {}"
.
format
(
dx_grad
,
pd_dx_grad
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录