Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
af89b659
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
af89b659
编写于
6月 18, 2019
作者:
H
hong19860320
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add arm kernel for fusion_elementwise_add_activation op
test=develop
上级
ce6c24e6
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
525 addition
and
8 deletion
+525
-8
paddle/fluid/lite/arm/math/elementwise.cc
paddle/fluid/lite/arm/math/elementwise.cc
+131
-3
paddle/fluid/lite/arm/math/elementwise.h
paddle/fluid/lite/arm/math/elementwise.h
+9
-2
paddle/fluid/lite/kernels/arm/CMakeLists.txt
paddle/fluid/lite/kernels/arm/CMakeLists.txt
+3
-3
paddle/fluid/lite/kernels/arm/elementwise_compute.cc
paddle/fluid/lite/kernels/arm/elementwise_compute.cc
+111
-0
paddle/fluid/lite/kernels/arm/elementwise_compute.h
paddle/fluid/lite/kernels/arm/elementwise_compute.h
+8
-0
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
+263
-0
未找到文件。
paddle/fluid/lite/arm/math/elementwise.cc
浏览文件 @
af89b659
...
...
@@ -65,9 +65,61 @@ void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
}
template
<
>
void
elementwise_add_axis
<
float
>
(
const
float
*
dinx
,
const
float
*
diny
,
float
*
dout
,
int
batch
,
int
channels
,
int
num
)
{
void
elementwise_add_relu
<
float
>
(
const
float
*
dinx
,
const
float
*
diny
,
float
*
dout
,
int
num
)
{
int
cnt
=
num
>>
4
;
int
remain
=
num
%
16
;
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
cnt
;
i
++
)
{
const
float
*
dinx_ptr
=
dinx
+
(
i
<<
4
);
const
float
*
diny_ptr
=
diny
+
(
i
<<
4
);
float
*
dout_ptr
=
dout
+
(
i
<<
4
);
float32x4_t
dinx0
=
vld1q_f32
(
dinx_ptr
);
float32x4_t
dinx1
=
vld1q_f32
(
dinx_ptr
+
4
);
float32x4_t
dinx2
=
vld1q_f32
(
dinx_ptr
+
8
);
float32x4_t
dinx3
=
vld1q_f32
(
dinx_ptr
+
12
);
float32x4_t
diny0
=
vld1q_f32
(
diny_ptr
);
float32x4_t
diny1
=
vld1q_f32
(
diny_ptr
+
4
);
float32x4_t
diny2
=
vld1q_f32
(
diny_ptr
+
8
);
float32x4_t
diny3
=
vld1q_f32
(
diny_ptr
+
12
);
dinx0
=
vaddq_f32
(
dinx0
,
diny0
);
dinx1
=
vaddq_f32
(
dinx1
,
diny1
);
dinx2
=
vaddq_f32
(
dinx2
,
diny2
);
dinx3
=
vaddq_f32
(
dinx3
,
diny3
);
// relu
dinx0
=
vmaxq_f32
(
dinx0
,
vzero
);
dinx1
=
vmaxq_f32
(
dinx1
,
vzero
);
dinx2
=
vmaxq_f32
(
dinx2
,
vzero
);
dinx3
=
vmaxq_f32
(
dinx3
,
vzero
);
vst1q_f32
(
dout_ptr
,
dinx0
);
vst1q_f32
(
dout_ptr
+
4
,
dinx1
);
vst1q_f32
(
dout_ptr
+
8
,
dinx2
);
vst1q_f32
(
dout_ptr
+
12
,
dinx3
);
}
if
(
remain
>
0
)
{
const
float
*
dinx_ptr
=
dinx
+
(
cnt
<<
4
);
const
float
*
diny_ptr
=
diny
+
(
cnt
<<
4
);
float
*
dout_ptr
=
dout
+
(
cnt
<<
4
);
for
(
int
i
=
0
;
i
<
remain
;
i
++
)
{
float
tmp
=
*
dinx_ptr
+
*
diny_ptr
;
*
dout_ptr
=
tmp
>
0.
f
?
tmp
:
0.
f
;
dout_ptr
++
;
dinx_ptr
++
;
diny_ptr
++
;
}
}
}
template
<
>
void
elementwise_add_broadcast
<
float
>
(
const
float
*
dinx
,
const
float
*
diny
,
float
*
dout
,
int
batch
,
int
channels
,
int
num
)
{
#pragma omp parallel for collapse(2)
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
...
...
@@ -127,6 +179,82 @@ void elementwise_add_axis<float>(const float* dinx, const float* diny,
}
}
template
<
>
void
elementwise_add_relu_broadcast
<
float
>
(
const
float
*
dinx
,
const
float
*
diny
,
float
*
dout
,
int
batch
,
int
channels
,
int
num
)
{
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
#pragma omp parallel for collapse(2)
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
float
*
din_ptr
=
dinx
+
offset
;
const
float
diny_data
=
diny
[
j
];
float
*
dout_ptr
=
dout
+
offset
;
int
cnt
=
num
>>
4
;
int
remain
=
num
%
16
;
float32x4_t
rb
=
vdupq_n_f32
(
diny_data
);
for
(
int
k
=
0
;
k
<
cnt
;
++
k
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
din2
=
vld1q_f32
(
din_ptr
+
8
);
float32x4_t
din3
=
vld1q_f32
(
din_ptr
+
12
);
din0
=
vaddq_f32
(
din0
,
rb
);
din1
=
vaddq_f32
(
din1
,
rb
);
din2
=
vaddq_f32
(
din2
,
rb
);
din3
=
vaddq_f32
(
din3
,
rb
);
// relu
din0
=
vmaxq_f32
(
din0
,
vzero
);
din1
=
vmaxq_f32
(
din1
,
vzero
);
din2
=
vmaxq_f32
(
din2
,
vzero
);
din3
=
vmaxq_f32
(
din3
,
vzero
);
vst1q_f32
(
dout_ptr
,
din0
);
vst1q_f32
(
dout_ptr
+
4
,
din1
);
vst1q_f32
(
dout_ptr
+
8
,
din2
);
vst1q_f32
(
dout_ptr
+
12
,
din3
);
din_ptr
+=
16
;
dout_ptr
+=
16
;
}
if
(
remain
>=
8
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
din0
=
vaddq_f32
(
din0
,
rb
);
din1
=
vaddq_f32
(
din1
,
rb
);
// relu
din0
=
vmaxq_f32
(
din0
,
vzero
);
din1
=
vmaxq_f32
(
din1
,
vzero
);
vst1q_f32
(
dout_ptr
,
din0
);
vst1q_f32
(
dout_ptr
+
4
,
din1
);
din_ptr
+=
8
;
dout_ptr
+=
8
;
remain
-=
8
;
}
if
(
remain
>=
4
)
{
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
din0
=
vaddq_f32
(
din0
,
rb
);
// relu
din0
=
vmaxq_f32
(
din0
,
vzero
);
vst1q_f32
(
dout_ptr
,
din0
);
din_ptr
+=
4
;
dout_ptr
+=
4
;
remain
-=
4
;
}
if
(
remain
>
0
)
{
for
(
int
p
=
0
;
p
<
remain
;
p
++
)
{
float
tmp
=
*
din_ptr
+
diny_data
;
*
dout_ptr
=
tmp
>
0.
f
?
tmp
:
0.
f
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
}
}
// namespace math
}
// namespace arm
}
// namespace lite
...
...
paddle/fluid/lite/arm/math/elementwise.h
浏览文件 @
af89b659
...
...
@@ -23,8 +23,15 @@ template <typename T>
void
elementwise_add
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
template
<
typename
T
>
void
elementwise_add_axis
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
void
elementwise_add_relu
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
num
);
template
<
typename
T
>
void
elementwise_add_broadcast
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
template
<
typename
T
>
void
elementwise_add_relu_broadcast
(
const
T
*
dinx
,
const
T
*
diny
,
T
*
dout
,
int
batch
,
int
channels
,
int
num
);
}
// namespace math
}
// namespace arm
...
...
paddle/fluid/lite/kernels/arm/CMakeLists.txt
浏览文件 @
af89b659
...
...
@@ -11,7 +11,7 @@ cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math
cc_library
(
softmax_compute_arm SRCS softmax_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
conv_compute_arm SRCS conv_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
elementwise_
add_compute_arm SRCS elementwise_add
_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
elementwise_
compute_arm SRCS elementwise
_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
pool_compute_arm SRCS pool_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
split_compute_arm SRCS split_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
cc_library
(
concat_compute_arm SRCS concat_compute.cc DEPS
${
lite_kernel_deps
}
math_arm
)
...
...
@@ -24,7 +24,7 @@ lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_comput
lite_cc_test
(
test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm
)
lite_cc_test
(
test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm
)
lite_cc_test
(
test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm
)
lite_cc_test
(
test_elementwise_
add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add
_compute_arm
)
lite_cc_test
(
test_elementwise_
compute_arm SRCS elementwise_compute_test.cc DEPS elementwise
_compute_arm
)
lite_cc_test
(
test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm
)
lite_cc_test
(
test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm
)
lite_cc_test
(
test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm
)
...
...
@@ -40,7 +40,7 @@ set(arm_kernels
softmax_compute_arm
conv_compute_arm
batch_norm_compute_arm
elementwise_
add_
compute_arm
elementwise_compute_arm
pool_compute_arm
split_compute_arm
concat_compute_arm
...
...
paddle/fluid/lite/kernels/arm/elementwise_
add_
compute.cc
→
paddle/fluid/lite/kernels/arm/elementwise_compute.cc
浏览文件 @
af89b659
...
...
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h"
#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h"
#include <string>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace
paddle
{
...
...
@@ -20,6 +21,30 @@ namespace lite {
namespace
kernels
{
namespace
arm
{
inline
bool
is_broadcast
(
const
DDim
&
x_dims
,
const
DDim
&
y_dims
,
int
axis
,
int
*
pre
,
int
*
n
,
int
*
post
)
{
if
(
axis
<
0
)
{
axis
=
x_dims
.
size
()
-
y_dims
.
size
();
}
if
(
x_dims
.
size
()
==
y_dims
.
size
())
{
return
false
;
}
*
pre
=
1
;
*
n
=
1
;
*
post
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
(
*
pre
)
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
CHECK_EQ
(
x_dims
[
i
+
axis
],
y_dims
[
i
])
<<
"Broadcast dimension mismatch."
;
(
*
n
)
*=
y_dims
[
i
];
}
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
(
*
post
)
*=
x_dims
[
i
];
}
return
true
;
}
void
ElementwiseAddCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
ElementwiseParam
>
();
const
float
*
x_data
=
param
.
X
->
data
<
float
>
();
...
...
@@ -28,27 +53,40 @@ void ElementwiseAddCompute::Run() {
int
axis
=
param
.
axis
;
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
if
(
axis
<
0
)
{
axis
=
x_dims
.
size
()
-
y_dims
.
size
();
}
if
(
x_dims
.
size
()
==
y_dims
.
size
())
{
int
pre
,
n
,
post
;
if
(
is_broadcast
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
))
{
lite
::
arm
::
math
::
elementwise_add_broadcast
(
x_data
,
y_data
,
out_data
,
pre
,
n
,
post
);
}
else
{
lite
::
arm
::
math
::
elementwise_add
(
x_data
,
y_data
,
out_data
,
x_dims
.
production
());
}
else
{
int
batch
=
1
;
int
channels
=
1
;
int
num
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
batch
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
channels
*=
y_dims
[
i
];
}
}
void
ElementwiseAddActivationCompute
::
Run
()
{
auto
&
param
=
Param
<
operators
::
FusionElementwiseActivationParam
>
();
const
float
*
x_data
=
param
.
X
->
data
<
float
>
();
const
float
*
y_data
=
param
.
Y
->
data
<
float
>
();
float
*
out_data
=
param
.
Out
->
mutable_data
<
float
>
();
int
axis
=
param
.
axis
;
std
::
string
act_type
=
param
.
act_type
;
auto
x_dims
=
param
.
X
->
dims
();
auto
y_dims
=
param
.
Y
->
dims
();
int
pre
,
n
,
post
;
if
(
is_broadcast
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
))
{
if
(
act_type
==
"relu"
)
{
lite
::
arm
::
math
::
elementwise_add_relu_broadcast
(
x_data
,
y_data
,
out_data
,
pre
,
n
,
post
);
}
else
{
LOG
(
FATAL
)
<<
"unsupported Activation type: "
<<
act_type
;
}
for
(
int
i
=
y_dims
.
size
()
+
axis
;
i
<
x_dims
.
size
();
++
i
)
{
num
*=
x_dims
[
i
];
}
else
{
if
(
act_type
==
"relu"
)
{
lite
::
arm
::
math
::
elementwise_add_relu
(
x_data
,
y_data
,
out_data
,
x_dims
.
production
());
}
else
{
LOG
(
FATAL
)
<<
"unsupported Activation type: "
<<
act_type
;
}
lite
::
arm
::
math
::
elementwise_add_axis
(
x_data
,
y_data
,
out_data
,
batch
,
channels
,
num
);
}
}
...
...
@@ -63,3 +101,11 @@ REGISTER_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW,
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
fusion_elementwise_add_activation
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ElementwiseAddActivationCompute
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
paddle/fluid/lite/kernels/arm/elementwise_
add_
compute.h
→
paddle/fluid/lite/kernels/arm/elementwise_compute.h
浏览文件 @
af89b659
...
...
@@ -30,6 +30,14 @@ class ElementwiseAddCompute
virtual
~
ElementwiseAddCompute
()
=
default
;
};
class
ElementwiseAddActivationCompute
:
public
KernelLite
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
;
virtual
~
ElementwiseAddActivationCompute
()
=
default
;
};
}
// namespace arm
}
// namespace kernels
}
// namespace lite
...
...
paddle/fluid/lite/kernels/arm/elementwise_
add_
compute_test.cc
→
paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc
浏览文件 @
af89b659
...
...
@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/elementwise_
add_
compute.h"
#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
...
...
@@ -37,7 +38,9 @@ TEST(elementwise_add_arm, init) {
}
template
<
typename
dtype
>
void
elementwise_add_compute_ref
(
const
operators
::
ElementwiseParam
&
param
)
{
void
elementwise_compute_ref
(
const
operators
::
ElementwiseParam
&
param
,
const
std
::
string
elt_type
,
const
std
::
string
act_type
)
{
const
dtype
*
x_data
=
param
.
X
->
data
<
const
dtype
>
();
const
dtype
*
y_data
=
param
.
Y
->
data
<
const
dtype
>
();
dtype
*
out_data
=
param
.
Out
->
mutable_data
<
dtype
>
();
...
...
@@ -59,17 +62,52 @@ void elementwise_add_compute_ref(const operators::ElementwiseParam& param) {
for
(
int
i
=
y_dims
.
size
()
+
axis
;
i
<
x_dims
.
size
();
++
i
)
{
num
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
dtype
*
din_ptr
=
x_data
+
offset
;
const
dtype
diny_data
=
y_data
[
j
];
dtype
*
dout_ptr
=
out_data
+
offset
;
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
*
dout_ptr
=
*
din_ptr
+
diny_data
;
dout_ptr
++
;
din_ptr
++
;
// do elementwise add/sub/max...
if
(
elt_type
==
"add"
)
{
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
dtype
*
din_ptr
=
x_data
+
offset
;
const
dtype
diny_data
=
y_data
[
j
];
dtype
*
dout_ptr
=
out_data
+
offset
;
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
*
dout_ptr
=
*
din_ptr
+
diny_data
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
else
if
(
elt_type
==
"sub"
)
{
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
int
offset
=
(
i
*
channels
+
j
)
*
num
;
const
dtype
*
din_ptr
=
x_data
+
offset
;
const
dtype
diny_data
=
y_data
[
j
];
dtype
*
dout_ptr
=
out_data
+
offset
;
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
*
dout_ptr
=
*
din_ptr
-
diny_data
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
}
else
{
LOG
(
FATAL
)
<<
"unsupported Elementwise type: "
<<
elt_type
;
}
// do activation relu/sigmod...
if
(
act_type
.
size
()
>
0
)
{
if
(
act_type
==
"relu"
)
{
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
for
(
int
j
=
0
;
j
<
channels
;
++
j
)
{
dtype
*
dout_ptr
=
out_data
+
(
i
*
channels
+
j
)
*
num
;
for
(
int
k
=
0
;
k
<
num
;
++
k
)
{
*
dout_ptr
=
*
dout_ptr
>
0.0
f
?
*
dout_ptr
:
0.0
f
;
dout_ptr
++
;
}
}
}
}
else
{
LOG
(
FATAL
)
<<
"unsupported Activation type: "
<<
elt_type
;
}
}
}
...
...
@@ -123,7 +161,7 @@ TEST(elementwise_add, compute) {
elementwise_add
.
SetParam
(
param
);
elementwise_add
.
Run
();
param
.
Out
=
&
output_ref
;
elementwise_
add_compute_ref
<
float
>
(
param
);
elementwise_
compute_ref
<
float
>
(
param
,
"add"
,
""
);
for
(
int
i
=
0
;
i
<
output
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
output_data
[
i
],
output_ref_data
[
i
],
1e-5
);
}
...
...
@@ -135,9 +173,91 @@ TEST(elementwise_add, compute) {
}
}
TEST
(
fusion_elementwise_add_activation_arm
,
retrive_op
)
{
auto
fusion_elementwise_add_activation
=
KernelRegistry
::
Global
().
Create
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
(
"fusion_elementwise_add_activation"
);
ASSERT_FALSE
(
fusion_elementwise_add_activation
.
empty
());
ASSERT_TRUE
(
fusion_elementwise_add_activation
.
front
());
}
TEST
(
fusion_elementwise_add_activation_arm
,
init
)
{
ElementwiseAddActivationCompute
fusion_elementwise_add_activation
;
ASSERT_EQ
(
fusion_elementwise_add_activation
.
precision
(),
PRECISION
(
kFloat
));
ASSERT_EQ
(
fusion_elementwise_add_activation
.
target
(),
TARGET
(
kARM
));
}
TEST
(
fusion_elementwise_add_activation_arm
,
compute
)
{
ElementwiseAddActivationCompute
fusion_elementwise_add_activation
;
operators
::
FusionElementwiseActivationParam
param
;
lite
::
Tensor
x
,
y
,
output
,
output_ref
;
for
(
auto
act_type
:
{
"relu"
})
{
for
(
auto
n
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
c
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
h
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
w
:
{
1
,
3
,
4
,
11
})
{
for
(
auto
axis
:
{
-
1
,
0
,
1
,
2
,
3
})
{
for
(
auto
yd
:
{
std
::
vector
<
int64_t
>
({
n
}),
std
::
vector
<
int64_t
>
({
c
}),
std
::
vector
<
int64_t
>
({
h
}),
std
::
vector
<
int64_t
>
({
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
}),
std
::
vector
<
int64_t
>
({
c
,
h
}),
std
::
vector
<
int64_t
>
({
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
}),
std
::
vector
<
int64_t
>
({
c
,
h
,
w
}),
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
})})
{
auto
x_dim
=
DDim
(
std
::
vector
<
int64_t
>
({
n
,
c
,
h
,
w
}));
auto
y_dim
=
DDim
(
yd
);
int
axis_t
=
axis
<
0
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
;
if
(
axis_t
+
y_dim
.
size
()
>
4
)
continue
;
bool
flag
=
false
;
for
(
int
i
=
0
;
i
<
y_dim
.
size
();
i
++
)
{
if
(
x_dim
[
i
+
axis_t
]
!=
y_dim
[
i
])
flag
=
true
;
}
if
(
flag
)
continue
;
x
.
Resize
(
x_dim
);
y
.
Resize
(
y_dim
);
output
.
Resize
(
x_dim
);
output_ref
.
Resize
(
x_dim
);
auto
*
x_data
=
x
.
mutable_data
<
float
>
();
auto
*
y_data
=
y
.
mutable_data
<
float
>
();
auto
*
output_data
=
output
.
mutable_data
<
float
>
();
auto
*
output_ref_data
=
output_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_dim
.
production
();
i
++
)
{
float
sign
=
i
%
3
==
0
?
-
1.0
f
:
1.0
f
;
x_data
[
i
]
=
i
*
sign
;
}
for
(
int
i
=
0
;
i
<
y_dim
.
production
();
i
++
)
{
float
sign
=
i
%
2
==
0
?
0.5
f
:
-
0.5
f
;
y_data
[
i
]
=
i
*
sign
;
}
param
.
X
=
&
x
;
param
.
Y
=
&
y
;
param
.
axis
=
axis
;
param
.
Out
=
&
output
;
param
.
act_type
=
act_type
;
fusion_elementwise_add_activation
.
SetParam
(
param
);
fusion_elementwise_add_activation
.
Run
();
param
.
Out
=
&
output_ref
;
elementwise_compute_ref
<
float
>
(
param
,
"add"
,
act_type
);
for
(
int
i
=
0
;
i
<
output
.
dims
().
production
();
i
++
)
{
EXPECT_NEAR
(
output_data
[
i
],
output_ref_data
[
i
],
1e-5
);
}
}
}
}
}
}
}
}
}
}
// namespace arm
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
USE_LITE_KERNEL
(
elementwise_add
,
kARM
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
fusion_elementwise_add_activation
,
kARM
,
kFloat
,
kNCHW
,
def
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录