Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
930b209e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
930b209e
编写于
12月 16, 2022
作者:
H
HongyuJia
提交者:
GitHub
12月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Custom Extension] Add xpu backward testcase (#49027)
* add xpu backward testcase * polish code * fix self.custom_op error
上级
1ca86fc6
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
383 addition
and
3 deletion
+383
-3
python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
+131
-0
python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
...le/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
+252
-3
未找到文件。
python/paddle/fluid/tests/custom_op/custom_relu_op_xpu.cc
浏览文件 @
930b209e
...
...
@@ -31,6 +31,28 @@ void relu_cpu_forward_kernel(const data_t* x_data,
}
}
template
<
typename
data_t
>
void
relu_cpu_backward_kernel
(
const
data_t
*
grad_out_data
,
const
data_t
*
out_data
,
data_t
*
grad_x_data
,
int64_t
out_numel
)
{
for
(
int64_t
i
=
0
;
i
<
out_numel
;
++
i
)
{
grad_x_data
[
i
]
=
grad_out_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0
)
?
1.
:
0.
);
}
}
template
<
typename
data_t
>
void
relu_cpu_double_backward_kernel
(
const
data_t
*
out_data
,
const
data_t
*
ddx_data
,
data_t
*
ddout_data
,
int64_t
ddout_numel
)
{
for
(
int64_t
i
=
0
;
i
<
ddout_numel
;
++
i
)
{
ddout_data
[
i
]
=
ddx_data
[
i
]
*
(
out_data
[
i
]
>
static_cast
<
data_t
>
(
0
)
?
1.
:
0.
);
}
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_forward
(
const
paddle
::
Tensor
&
x
)
{
CHECK_CPU_INPUT
(
x
);
auto
out
=
paddle
::
empty_like
(
x
);
...
...
@@ -44,12 +66,81 @@ std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
return
{
out
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_backward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
auto
grad_x
=
paddle
::
empty_like
(
x
);
PD_DISPATCH_FLOATING_TYPES
(
out
.
type
(),
"relu_cpu_backward"
,
([
&
]
{
relu_cpu_backward_kernel
<
data_t
>
(
grad_out
.
data
<
data_t
>
(),
out
.
data
<
data_t
>
(),
grad_x
.
data
<
data_t
>
(),
out
.
size
());
}));
return
{
grad_x
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_cpu_double_backward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
CHECK_CPU_INPUT
(
out
);
CHECK_CPU_INPUT
(
ddx
);
auto
ddout
=
paddle
::
empty
(
out
.
shape
(),
out
.
dtype
(),
out
.
place
());
PD_DISPATCH_FLOATING_TYPES
(
out
.
type
(),
"relu_cpu_double_backward"
,
([
&
]
{
relu_cpu_double_backward_kernel
<
data_t
>
(
out
.
data
<
data_t
>
(),
ddx
.
data
<
data_t
>
(),
ddout
.
mutable_data
<
data_t
>
(
out
.
place
()),
ddout
.
size
());
}));
std
::
cout
<<
"Debug info: run relu cpu double backward success."
<<
std
::
endl
;
return
{
ddout
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_xpu_forward
(
const
paddle
::
Tensor
&
x
)
{
CHECK_XPU_INPUT
(
x
);
auto
out
=
paddle
::
relu
(
x
);
return
{
out
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_xpu_backward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
CHECK_XPU_INPUT
(
x
);
CHECK_XPU_INPUT
(
out
);
CHECK_XPU_INPUT
(
grad_out
);
auto
grad_x
=
paddle
::
empty_like
(
x
,
x
.
dtype
(),
x
.
place
());
auto
ones
=
paddle
::
experimental
::
full_like
(
x
,
1.0
,
x
.
dtype
(),
x
.
place
());
auto
zeros
=
paddle
::
experimental
::
full_like
(
x
,
0.0
,
x
.
dtype
(),
x
.
place
());
auto
condition
=
paddle
::
experimental
::
greater_than
(
x
,
zeros
);
grad_x
=
paddle
::
multiply
(
grad_out
,
paddle
::
where
(
condition
,
ones
,
zeros
));
return
{
grad_x
};
}
std
::
vector
<
paddle
::
Tensor
>
relu_xpu_double_backward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
CHECK_XPU_INPUT
(
out
);
CHECK_XPU_INPUT
(
ddx
);
auto
ddout
=
paddle
::
empty
(
out
.
shape
(),
out
.
dtype
(),
out
.
place
());
auto
ones
=
paddle
::
experimental
::
full_like
(
out
,
1.0
,
out
.
dtype
(),
out
.
place
());
auto
zeros
=
paddle
::
experimental
::
full_like
(
out
,
0.0
,
out
.
dtype
(),
out
.
place
());
auto
condition
=
paddle
::
experimental
::
greater_than
(
out
,
zeros
);
ddout
=
paddle
::
multiply
(
ddx
,
paddle
::
where
(
condition
,
ones
,
zeros
));
std
::
cout
<<
"Debug info: run relu cpu double backward success."
<<
std
::
endl
;
return
{
ddout
};
}
std
::
vector
<
paddle
::
Tensor
>
ReluForward
(
const
paddle
::
Tensor
&
x
)
{
if
(
x
.
is_cpu
())
{
return
relu_cpu_forward
(
x
);
...
...
@@ -60,7 +151,47 @@ std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
}
}
std
::
vector
<
paddle
::
Tensor
>
ReluBackward
(
const
paddle
::
Tensor
&
x
,
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
grad_out
)
{
if
(
x
.
is_cpu
())
{
return
relu_cpu_backward
(
x
,
out
,
grad_out
);
}
else
if
(
x
.
is_xpu
())
{
return
relu_xpu_backward
(
x
,
out
,
grad_out
);
}
else
{
PD_THROW
(
"Not implemented."
);
}
}
std
::
vector
<
paddle
::
Tensor
>
ReluDoubleBackward
(
const
paddle
::
Tensor
&
out
,
const
paddle
::
Tensor
&
ddx
)
{
if
(
out
.
place
()
==
paddle
::
PlaceType
::
kCPU
)
{
return
relu_cpu_double_backward
(
out
,
ddx
);
}
else
if
(
out
.
place
().
GetType
()
==
phi
::
AllocationType
::
XPU
)
{
return
relu_xpu_double_backward
(
out
,
ddx
);
}
else
{
PD_THROW
(
"Not implemented."
);
}
}
std
::
vector
<
std
::
vector
<
int64_t
>>
ReluDoubleBackwardInferShape
(
const
std
::
vector
<
int64_t
>&
out_shape
,
const
std
::
vector
<
int64_t
>&
ddx_shape
)
{
return
{
out_shape
};
}
PD_BUILD_OP
(
custom_relu
)
.
Inputs
({
"X"
})
.
Outputs
({
"Out"
})
.
SetKernelFn
(
PD_KERNEL
(
ReluForward
));
PD_BUILD_GRAD_OP
(
custom_relu
)
.
Inputs
({
"X"
,
"Out"
,
paddle
::
Grad
(
"Out"
)})
.
Outputs
({
paddle
::
Grad
(
"X"
)})
.
SetKernelFn
(
PD_KERNEL
(
ReluBackward
));
PD_BUILD_DOUBLE_GRAD_OP
(
custom_relu
)
.
Inputs
({
"Out"
,
paddle
::
Grad
(
paddle
::
Grad
(
"X"
))})
.
Outputs
({
paddle
::
Grad
(
paddle
::
Grad
(
"Out"
))})
.
SetKernelFn
(
PD_KERNEL
(
ReluDoubleBackward
))
.
SetInferShapeFn
(
PD_INFER_SHAPE
(
ReluDoubleBackwardInferShape
));
python/paddle/fluid/tests/custom_op/test_custom_relu_op_xpu_setup.py
浏览文件 @
930b209e
...
...
@@ -23,15 +23,24 @@ import paddle
import
paddle.static
as
static
from
paddle.fluid.framework
import
_test_eager_guard
from
paddle.utils.cpp_extension.extension_utils
import
run_cmd
from
paddle.vision.transforms
import
Compose
,
Normalize
def
custom_relu_dynamic
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
):
paddle
.
set_device
(
device
)
t
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
)
t
.
stop_gradient
=
False
out
=
func
(
t
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
t
)
out
.
stop_gradient
=
False
out
.
backward
()
return
out
.
numpy
()
if
t
.
grad
is
None
:
return
out
.
numpy
(),
t
.
grad
else
:
return
out
.
numpy
(),
t
.
grad
.
numpy
()
def
custom_relu_static
(
...
...
@@ -43,7 +52,9 @@ def custom_relu_static(
with
static
.
scope_guard
(
static
.
Scope
()):
with
static
.
program_guard
(
static
.
Program
()):
x
=
static
.
data
(
name
=
'X'
,
shape
=
[
None
,
8
],
dtype
=
dtype
)
x
.
stop_gradient
=
False
out
=
func
(
x
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
x
)
static
.
append_backward
(
out
)
exe
=
static
.
Executor
()
exe
.
run
(
static
.
default_startup_program
())
...
...
@@ -58,6 +69,97 @@ def custom_relu_static(
return
out_v
def
custom_relu_static_pe
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
):
paddle
.
enable_static
()
paddle
.
set_device
(
device
)
places
=
static
.
xpu_places
()
with
static
.
scope_guard
(
static
.
Scope
()):
with
static
.
program_guard
(
static
.
Program
()):
x
=
static
.
data
(
name
=
'X'
,
shape
=
[
None
,
8
],
dtype
=
dtype
)
x
.
stop_gradient
=
False
out
=
func
(
x
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
x
)
static
.
append_backward
(
out
)
exe
=
static
.
Executor
()
exe
.
run
(
static
.
default_startup_program
())
# in static mode, x data has been covered by out
compiled_prog
=
static
.
CompiledProgram
(
static
.
default_main_program
()
).
with_data_parallel
(
loss_name
=
out
.
name
,
places
=
places
)
out_v
=
exe
.
run
(
compiled_prog
,
feed
=
{
'X'
:
np_x
},
fetch_list
=
[
out
.
name
]
)
paddle
.
disable_static
()
return
out_v
def
custom_relu_static_inference
(
func
,
device
,
np_data
,
np_label
,
path_prefix
):
paddle
.
set_device
(
device
)
with
static
.
scope_guard
(
static
.
Scope
()):
with
static
.
program_guard
(
static
.
Program
()):
# simple module
data
=
static
.
data
(
name
=
'data'
,
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
static
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
hidden
=
static
.
nn
.
fc
(
data
,
size
=
128
)
hidden
=
func
(
hidden
)
hidden
=
static
.
nn
.
fc
(
hidden
,
size
=
128
)
predict
=
static
.
nn
.
fc
(
hidden
,
size
=
10
,
activation
=
'softmax'
)
loss
=
paddle
.
nn
.
functional
.
cross_entropy
(
input
=
hidden
,
label
=
label
)
avg_loss
=
paddle
.
mean
(
loss
)
opt
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.1
)
opt
.
minimize
(
avg_loss
)
# run start up model
exe
=
static
.
Executor
()
exe
.
run
(
static
.
default_startup_program
())
# train
for
_
in
range
(
4
):
exe
.
run
(
static
.
default_main_program
(),
feed
=
{
'data'
:
np_data
,
'label'
:
np_label
},
fetch_list
=
[
avg_loss
],
)
# save inference model
static
.
save_inference_model
(
path_prefix
,
[
data
],
[
predict
],
exe
)
# get train predict value
predict_v
=
exe
.
run
(
static
.
default_main_program
(),
feed
=
{
'data'
:
np_data
,
'label'
:
np_label
},
fetch_list
=
[
predict
],
)
return
predict_v
def
custom_relu_double_grad_dynamic
(
func
,
device
,
dtype
,
np_x
,
use_func
=
True
):
paddle
.
set_device
(
device
)
t
=
paddle
.
to_tensor
(
np_x
,
dtype
=
dtype
,
stop_gradient
=
False
)
out
=
func
(
t
)
if
use_func
else
paddle
.
nn
.
functional
.
relu
(
t
)
out
.
stop_gradient
=
False
dx
=
paddle
.
grad
(
outputs
=
[
out
],
inputs
=
[
t
],
create_graph
=
True
,
retain_graph
=
True
)
dx
[
0
].
backward
()
assert
dx
[
0
].
grad
is
not
None
return
dx
[
0
].
numpy
(),
dx
[
0
].
grad
.
numpy
()
class
TestNewCustomOpSetUpInstall
(
unittest
.
TestCase
):
def
setUp
(
self
):
cur_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
...
...
@@ -110,12 +212,30 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
),
)
def
test_static_pe
(
self
):
for
device
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
=
custom_relu_static_pe
(
self
.
custom_op
,
device
,
dtype
,
x
)
pd_out
=
custom_relu_static_pe
(
self
.
custom_op
,
device
,
dtype
,
x
,
False
)
np
.
testing
.
assert_array_equal
(
out
,
pd_out
,
err_msg
=
'custom op out: {},
\n
paddle api out: {}'
.
format
(
out
,
pd_out
),
)
def
func_dynamic
(
self
):
for
device
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
=
custom_relu_dynamic
(
self
.
custom_op
,
device
,
dtype
,
x
)
pd_out
=
custom_relu_dynamic
(
out
,
x_grad
=
custom_relu_dynamic
(
self
.
custom_op
,
device
,
dtype
,
x
)
pd_out
,
pd_x_grad
=
custom_relu_dynamic
(
self
.
custom_op
,
device
,
dtype
,
x
,
False
)
np
.
testing
.
assert_array_equal
(
...
...
@@ -125,12 +245,141 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
out
,
pd_out
),
)
np
.
testing
.
assert_array_equal
(
x_grad
,
pd_x_grad
,
err_msg
=
'custom op x grad: {},
\n
paddle api x grad: {}'
.
format
(
x_grad
,
pd_x_grad
),
)
def
test_dynamic
(
self
):
with
_test_eager_guard
():
self
.
func_dynamic
()
self
.
func_dynamic
()
def
test_static_save_and_load_inference_model
(
self
):
paddle
.
enable_static
()
np_data
=
np
.
random
.
random
((
1
,
1
,
28
,
28
)).
astype
(
"float32"
)
np_label
=
np
.
random
.
random
((
1
,
1
)).
astype
(
"int64"
)
path_prefix
=
"self.custom_op_inference/custom_relu"
for
device
in
self
.
devices
:
predict
=
custom_relu_static_inference
(
self
.
custom_op
,
device
,
np_data
,
np_label
,
path_prefix
)
# load inference model
with
static
.
scope_guard
(
static
.
Scope
()):
exe
=
static
.
Executor
()
[
inference_program
,
feed_target_names
,
fetch_targets
,
]
=
static
.
load_inference_model
(
path_prefix
,
exe
)
predict_infer
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
np_data
},
fetch_list
=
fetch_targets
,
)
np
.
testing
.
assert_array_equal
(
predict
,
predict_infer
,
err_msg
=
'custom op predict: {},
\n
custom op infer predict: {}'
.
format
(
predict
,
predict_infer
),
)
paddle
.
disable_static
()
def
test_static_save_and_run_inference_predictor
(
self
):
paddle
.
enable_static
()
np_data
=
np
.
random
.
random
((
1
,
1
,
28
,
28
)).
astype
(
"float32"
)
np_label
=
np
.
random
.
random
((
1
,
1
)).
astype
(
"int64"
)
path_prefix
=
"self.custom_op_inference/custom_relu"
from
paddle.inference
import
Config
,
create_predictor
for
device
in
self
.
devices
:
predict
=
custom_relu_static_inference
(
self
.
custom_op
,
device
,
np_data
,
np_label
,
path_prefix
)
# load inference model
config
=
Config
(
path_prefix
+
".pdmodel"
,
path_prefix
+
".pdiparams"
)
predictor
=
create_predictor
(
config
)
input_tensor
=
predictor
.
get_input_handle
(
predictor
.
get_input_names
()[
0
]
)
input_tensor
.
reshape
(
np_data
.
shape
)
input_tensor
.
copy_from_cpu
(
np_data
.
copy
())
predictor
.
run
()
output_tensor
=
predictor
.
get_output_handle
(
predictor
.
get_output_names
()[
0
]
)
predict_infer
=
output_tensor
.
copy_to_cpu
()
self
.
assertTrue
(
np
.
isclose
(
predict
,
predict_infer
,
rtol
=
5e-5
).
any
(),
"custom op predict: {},
\n
custom op infer predict: {}"
.
format
(
predict
,
predict_infer
),
)
paddle
.
disable_static
()
def
test_func_double_grad_dynamic
(
self
):
for
device
in
self
.
devices
:
for
dtype
in
self
.
dtypes
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
8
]).
astype
(
dtype
)
out
,
dx_grad
=
custom_relu_double_grad_dynamic
(
self
.
custom_op
,
device
,
dtype
,
x
)
pd_out
,
pd_dx_grad
=
custom_relu_double_grad_dynamic
(
self
.
custom_op
,
device
,
dtype
,
x
,
False
)
np
.
testing
.
assert_array_equal
(
out
,
pd_out
,
err_msg
=
'custom op out: {},
\n
paddle api out: {}'
.
format
(
out
,
pd_out
),
)
np
.
testing
.
assert_array_equal
(
dx_grad
,
pd_dx_grad
,
err_msg
=
'custom op dx grad: {},
\n
paddle api dx grad: {}'
.
format
(
dx_grad
,
pd_dx_grad
),
)
def
test_with_dataloader
(
self
):
for
device
in
self
.
devices
:
paddle
.
set_device
(
device
)
# data loader
transform
=
Compose
(
[
Normalize
(
mean
=
[
127.5
],
std
=
[
127.5
],
data_format
=
'CHW'
)]
)
train_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'train'
,
transform
=
transform
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
batch_size
=
64
,
shuffle
=
True
,
drop_last
=
True
,
num_workers
=
0
,
)
for
batch_id
,
(
image
,
_
)
in
enumerate
(
train_loader
()):
out
=
self
.
custom_op
(
image
)
pd_out
=
paddle
.
nn
.
functional
.
relu
(
image
)
np
.
testing
.
assert_array_equal
(
out
,
pd_out
,
err_msg
=
'custom op out: {},
\n
paddle api out: {}'
.
format
(
out
,
pd_out
),
)
if
batch_id
==
5
:
break
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录