Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
684c9197
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
684c9197
编写于
9月 16, 2020
作者:
Z
zhangwen31
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[arm][kernel]test: add elementwise compute test for sub mul div test=develop
上级
266965e2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
51 addition
and
11 deletion
+51
-11
lite/kernels/arm/elementwise_compute_test.cc
lite/kernels/arm/elementwise_compute_test.cc
+51
-11
未找到文件。
lite/kernels/arm/elementwise_compute_test.cc
浏览文件 @
684c9197
...
...
@@ -106,6 +106,20 @@ void elementwise_compute_ref(const operators::ElementwiseParam& param,
}
}
}
}
else
if
(
elt_type
==
"div"
)
{
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
dtype
*
din_ptr
=
x_data
+
offset
;
const
dtype
diny_data
=
y_data
[
j
];
dtype
*
dout_ptr
=
out_data
+
offset
;
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
*
dout_ptr
=
*
din_ptr
/
diny_data
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
else
if
(
elt_type
==
"max"
)
{
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
...
...
@@ -254,11 +268,14 @@ template void elementwise_imod_compute_ref<int32_t>(
template
void
elementwise_imod_compute_ref
<
int64_t
>(
const
operators
::
ElementwiseParam
&
param
,
const
std
::
string
act_type
);
template
<
typename
T
,
PrecisionType
PType
>
void
elementwise_add_compute
()
{
ElementwiseAddCompute
<
T
,
PType
>
elementwise_add
;
template
<
template
<
class
,
PrecisionType
>
class
ElementWiseComputeTemplate
,
typename
T
,
PrecisionType
PType
>
void
do_elementwise_compute
(
const
char
*
op_type_str
)
{
ElementWiseComputeTemplate
<
T
,
PType
>
elementwise_add
;
operators
::
ElementwiseParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
unsigned
int
rand_seed
=
1
;
#if 1
for
(
auto
n
:
{
1
,
3
,
4
})
{
...
...
@@ -311,10 +328,12 @@ void elementwise_add_compute() {
T
*
output_data
=
output
.
mutable_data
<
T
>
();
T
*
output_ref_data
=
output_ref
.
mutable_data
<
T
>
();
for
(
int
i
=
0
;
i
<
x_dim
.
production
();
i
++
)
{
x_data
[
i
]
=
i
;
x_data
[
i
]
=
1.0
*
rand_r
(
&
rand_seed
)
*
rand_r
(
&
rand_seed
)
/
(
rand_r
(
&
rand_seed
)
+
1
);
}
for
(
int
i
=
0
;
i
<
y_dim
.
production
();
i
++
)
{
y_data
[
i
]
=
i
;
y_data
[
i
]
=
1.0
*
rand_r
(
&
rand_seed
)
*
rand_r
(
&
rand_seed
)
/
(
rand_r
(
&
rand_seed
)
+
1
);
}
param
.
X
=
&
x
;
param
.
Y
=
&
y
;
...
...
@@ -323,15 +342,15 @@ void elementwise_add_compute() {
elementwise_add
.
SetParam
(
param
);
elementwise_add
.
Run
();
param
.
Out
=
&
output_ref
;
elementwise_compute_ref
<
T
>
(
param
,
"add"
,
""
);
elementwise_compute_ref
<
T
>
(
param
,
op_type_str
,
""
);
if
(
std
::
is_floating_point
<
T
>::
value
)
{
for
(
int
i
=
0
;
i
<
output
.
dims
().
production
();
i
++
)
{
EXPEC
T_NEAR
(
output_data
[
i
],
output_ref_data
[
i
],
1e-5
)
ASSER
T_NEAR
(
output_data
[
i
],
output_ref_data
[
i
],
1e-5
)
<<
"Value differ at index "
<<
i
;
}
}
else
{
for
(
int
i
=
0
;
i
<
output
.
dims
().
production
();
i
++
)
{
EXPEC
T_EQ
(
output_data
[
i
],
output_ref_data
[
i
])
ASSER
T_EQ
(
output_data
[
i
],
output_ref_data
[
i
])
<<
"Value differ at index "
<<
i
;
}
}
...
...
@@ -344,21 +363,42 @@ void elementwise_add_compute() {
}
TEST
(
elementwise_add
,
compute_fp32
)
{
elementwise_add_compute
<
float
,
PRECISION
(
kFloat
)
>
();
do_elementwise_compute
<
ElementwiseAddCompute
,
float
,
PRECISION
(
kFloat
)
>
(
"add"
);
do_elementwise_compute
<
ElementwiseSubCompute
,
float
,
PRECISION
(
kFloat
)
>
(
"sub"
);
do_elementwise_compute
<
ElementwiseMulCompute
,
float
,
PRECISION
(
kFloat
)
>
(
"mul"
);
do_elementwise_compute
<
ElementwiseDivCompute
,
float
,
PRECISION
(
kFloat
)
>
(
"div"
);
if
(
::
testing
::
Test
::
HasFailure
())
{
FAIL
();
}
}
TEST
(
elementwise_add
,
compute_i32
)
{
elementwise_add_compute
<
int32_t
,
PRECISION
(
kInt32
)
>
();
do_elementwise_compute
<
ElementwiseAddCompute
,
int32_t
,
PRECISION
(
kInt32
)
>
(
"add"
);
do_elementwise_compute
<
ElementwiseSubCompute
,
int32_t
,
PRECISION
(
kInt32
)
>
(
"sub"
);
do_elementwise_compute
<
ElementwiseMulCompute
,
int32_t
,
PRECISION
(
kInt32
)
>
(
"mul"
);
do_elementwise_compute
<
ElementwiseDivCompute
,
int32_t
,
PRECISION
(
kInt32
)
>
(
"div"
);
if
(
::
testing
::
Test
::
HasFailure
())
{
FAIL
();
}
}
TEST
(
elementwise_add
,
compute_i64
)
{
elementwise_add_compute
<
int64_t
,
PRECISION
(
kInt64
)
>
();
do_elementwise_compute
<
ElementwiseAddCompute
,
int64_t
,
PRECISION
(
kInt64
)
>
(
"add"
);
do_elementwise_compute
<
ElementwiseSubCompute
,
int64_t
,
PRECISION
(
kInt64
)
>
(
"sub"
);
do_elementwise_compute
<
ElementwiseMulCompute
,
int64_t
,
PRECISION
(
kInt64
)
>
(
"mul"
);
do_elementwise_compute
<
ElementwiseDivCompute
,
int64_t
,
PRECISION
(
kInt64
)
>
(
"div"
);
if
(
::
testing
::
Test
::
HasFailure
())
{
FAIL
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录