PaddlePaddle / PaddleDetection — commit 3a631fbb

Authored Jul 01, 2019 by Chunwei
Commit message: prune
Parent: 0f9e7057
Showing 16 changed files, with 0 additions and 4995 deletions (+0 −4995).
File                                         Additions  Deletions
paddle/fluid/lite/arm/math/activation.cc     +0         −520
paddle/fluid/lite/arm/math/activation.h      +0         −50
paddle/fluid/lite/arm/math/concat.cc         +0         −59
paddle/fluid/lite/arm/math/concat.h          +0         −34
paddle/fluid/lite/arm/math/dropout.cc        +0         −93
paddle/fluid/lite/arm/math/dropout.h         +0         −32
paddle/fluid/lite/arm/math/elementwise.cc    +0         −261
paddle/fluid/lite/arm/math/elementwise.h     +0         −39
paddle/fluid/lite/arm/math/pooling.cc        +0         −2859
paddle/fluid/lite/arm/math/pooling.h         +0         −73
paddle/fluid/lite/arm/math/scale.cc          +0         −169
paddle/fluid/lite/arm/math/scale.h           +0         −36
paddle/fluid/lite/arm/math/softmax.cc        +0         −601
paddle/fluid/lite/arm/math/softmax.h         +0         −52
paddle/fluid/lite/arm/math/split.cc          +0         −82
paddle/fluid/lite/arm/math/split.h           +0         −35
paddle/fluid/lite/arm/math/activation.cc — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/arm/math/activation.h"
#include "paddle/fluid/lite/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void act_relu<float>(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt = nums_per_thread >> 4;
  int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    int cnt = neon_loop_cnt;
#ifdef __aarch64__
    for (int num = 0; num < neon_loop_cnt; ++num) {
      float32x4_t vr0 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr1 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr2 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr3 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      vr0 = vmaxq_f32(vr0, vzero);
      vr1 = vmaxq_f32(vr1, vzero);
      vr2 = vmaxq_f32(vr2, vzero);
      vr3 = vmaxq_f32(vr3, vzero);
      vst1q_f32(ptr_out_thread, vr0);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vr1);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vr2);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vr3);
      ptr_out_thread += 4;
    }
#else
    if (cnt > 0) {
      asm volatile(
          "1:                              @ loop header\n"
          "vld1.32  {d0-d3}, [%[din]]!     @ load din 0\n"
          "vld1.32  {d4-d7}, [%[din]]!     @ load din 0\n"
          "vmax.f32 q8, q0, %q[vzero]      @ relu\n"
          "vmax.f32 q9, q1, %q[vzero]      @ relu\n"
          "vmax.f32 q10, q2, %q[vzero]     @ relu\n"
          "vmax.f32 q11, q3, %q[vzero]     @ relu\n"
          "vst1.32  {d16-d19}, [%[dout]]!  @ store result, add pointer\n"
          "vst1.32  {d20-d23}, [%[dout]]!  @ store result, add pointer\n"
          "subs %[cnt], #1                 @ loop count minus 1\n"
          "bne  1b                         @ jump to main loop start point\n"
          : [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
            [cnt] "+r"(cnt)
          : [vzero] "w"(vzero)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
    }
#endif
    for (int j = 0; j < neon_loop_remain; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f;
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* out_ptr_remain = dout + threads * nums_per_thread;
  const float* in_ptr_remain = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f;
    in_ptr_remain++;
    out_ptr_remain++;
  }
}

template <>
void act_relu_neg<float>(const float* din, float* dout, int size,
                         const float negative_slope, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt = nums_per_thread >> 4;
  int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
  float32x4_t vzero = vdupq_n_f32(0.f);
  float32x4_t valpha = vdupq_n_f32(negative_slope);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    int cnt = neon_loop_cnt;
#ifdef __aarch64__
    for (int num = 0; num < neon_loop_cnt; ++num) {
      float32x4_t vr0 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr1 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr2 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr3 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      uint32x4_t vm0 = vcgeq_f32(vr0, vzero);
      uint32x4_t vm1 = vcgeq_f32(vr1, vzero);
      uint32x4_t vm2 = vcgeq_f32(vr2, vzero);
      uint32x4_t vm3 = vcgeq_f32(vr3, vzero);
      float32x4_t vn0 = vmulq_f32(vr0, valpha);
      float32x4_t vn1 = vmulq_f32(vr1, valpha);
      float32x4_t vn2 = vmulq_f32(vr2, valpha);
      float32x4_t vn3 = vmulq_f32(vr3, valpha);
      float32x4_t vo0 = vbslq_f32(vm0, vr0, vn0);
      float32x4_t vo1 = vbslq_f32(vm1, vr1, vn1);
      float32x4_t vo2 = vbslq_f32(vm2, vr2, vn2);
      float32x4_t vo3 = vbslq_f32(vm3, vr3, vn3);
      vst1q_f32(ptr_out_thread, vo0);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo1);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo2);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo3);
      ptr_out_thread += 4;
    }
#else
    if (cnt > 0) {
      asm volatile(
          "1:                              @ loop header\n"
          "vld1.32  {d0-d3}, [%[din]]!     @ load din 0\n"
          "vld1.32  {d4-d7}, [%[din]]!     @ load din 0\n"
          "vcge.f32 q8, q0, %q[vzero]      @ get mask\n"
          "vcge.f32 q9, q1, %q[vzero]      @ get mask\n"
          "vcge.f32 q10, q2, %q[vzero]     @ get mask\n"
          "vcge.f32 q11, q3, %q[vzero]     @ get mask\n"
          "vmul.f32 q4, q0, %q[valpha]     @ get neg data\n"
          "vmul.f32 q5, q1, %q[valpha]     @ get neg data\n"
          "vmul.f32 q6, q2, %q[valpha]     @ get neg data\n"
          "vmul.f32 q7, q3, %q[valpha]     @ get neg data\n"
          "vbit q4, q0, q8                 @ bitsel, insert q0 to q4, if q8 is 1\n"
          "vbit q5, q1, q9                 @ bitsel, insert q1 to q5, if q9 is 1\n"
          "vbit q6, q2, q10                @ bitsel, insert q2 to q6, if q10 is 1\n"
          "vbit q7, q3, q11                @ bitsel, insert q3 to q7, if q11 is 1\n"
          "vst1.32  {d8-d11}, [%[dout]]!   @ store result, add pointer\n"
          "vst1.32  {d12-d15}, [%[dout]]!  @ store result, add pointer\n"
          "subs %[cnt], #1                 @ loop count minus 1\n"
          "bne  1b                         @ jump to main loop start point\n"
          : [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
            [cnt] "+r"(cnt)
          : [vzero] "w"(vzero), [valpha] "w"(valpha)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11");
    }
#endif
    for (int j = 0; j < neon_loop_remain; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] > 0.f
                              ? ptr_in_thread[0]
                              : ptr_in_thread[0] * negative_slope;
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* out_ptr_remain = dout + threads * nums_per_thread;
  const float* in_ptr_remain = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    out_ptr_remain[0] = in_ptr_remain[0] > 0.f
                            ? in_ptr_remain[0]
                            : in_ptr_remain[0] * negative_slope;
    in_ptr_remain++;
    out_ptr_remain++;
  }
}

template <>
void act_clipped_relu<float>(const float* din, float* dout, int size,
                             const float coef, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt = nums_per_thread >> 4;
  int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
  float32x4_t vzero = vdupq_n_f32(0.f);
  float32x4_t vclip = vdupq_n_f32(coef);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    int cnt = neon_loop_cnt;
#ifdef __aarch64__
    for (int num = 0; num < neon_loop_cnt; ++num) {
      float32x4_t vr0 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr1 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr2 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr3 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vt0 = vmaxq_f32(vr0, vzero);
      float32x4_t vt1 = vmaxq_f32(vr1, vzero);
      float32x4_t vt2 = vmaxq_f32(vr2, vzero);
      float32x4_t vt3 = vmaxq_f32(vr3, vzero);
      float32x4_t vo0 = vminq_f32(vt0, vclip);
      float32x4_t vo1 = vminq_f32(vt1, vclip);
      float32x4_t vo2 = vminq_f32(vt2, vclip);
      float32x4_t vo3 = vminq_f32(vt3, vclip);
      vst1q_f32(ptr_out_thread, vo0);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo1);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo2);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo3);
      ptr_out_thread += 4;
    }
#else
    if (cnt > 0) {
      asm volatile(
          "1:                              @ loop header\n"
          "vld1.32  {d0-d3}, [%[din]]!     @ load din 0\n"
          "vld1.32  {d4-d7}, [%[din]]!     @ load din 0\n"
          "vmax.f32 q8, q0, %q[vzero]      @ relu\n"
          "vmax.f32 q9, q1, %q[vzero]      @ relu\n"
          "vmax.f32 q10, q2, %q[vzero]     @ relu\n"
          "vmax.f32 q11, q3, %q[vzero]     @ relu\n"
          "vmin.f32 q4, q8, %q[vclip]      @ clip relu\n"
          "vmin.f32 q5, q9, %q[vclip]      @ clip relu\n"
          "vmin.f32 q6, q10, %q[vclip]     @ clip relu\n"
          "vmin.f32 q7, q11, %q[vclip]     @ clip relu\n"
          "vst1.32  {d8-d11}, [%[dout]]!   @ store result, add pointer\n"
          "vst1.32  {d12-d15}, [%[dout]]!  @ store result, add pointer\n"
          "subs %[cnt], #1                 @ loop count minus 1\n"
          "bne  1b                         @ jump to main loop start point\n"
          : [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
            [cnt] "+r"(cnt)
          : [vzero] "w"(vzero), [vclip] "w"(vclip)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11");
    }
#endif
    for (int j = 0; j < neon_loop_remain; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f;
      ptr_out_thread[0] = ptr_out_thread[0] < coef ? ptr_out_thread[0] : coef;
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* out_ptr_remain = dout + threads * nums_per_thread;
  const float* in_ptr_remain = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f;
    out_ptr_remain[0] = out_ptr_remain[0] < coef ? out_ptr_remain[0] : coef;
    in_ptr_remain++;
    out_ptr_remain++;
  }
}

template <>
void act_prelu<float>(const float* din, float* dout, int outer_size,
                      int channel_size, int inner_size, bool channel_shared,
                      float* channel_slope, int threads) {
  int stride_size = inner_size * channel_size;
  int cnt = inner_size >> 4;
  int remain = inner_size & 15;
  float32x4_t vzero = vdupq_n_f32(0.f);
  for (int n = 0; n < outer_size; n++) {
    const float* data_in_batch = din + n * stride_size;
    float* data_out_batch = dout + n * stride_size;
#pragma omp parallel for
    for (int c = 0; c < channel_size; c++) {
      const float* data_in_c = data_in_batch + c * inner_size;
      float* data_out_c = data_out_batch + c * inner_size;
      float slope = channel_shared ? channel_slope[0] : channel_slope[c];
      float32x4_t vslope = vdupq_n_f32(slope);
#ifdef __aarch64__
      for (int i = 0; i < cnt; ++i) {
        float32x4_t vr0 = vld1q_f32(data_in_c);
        float32x4_t vr1 = vld1q_f32(data_in_c + 4);
        float32x4_t vr2 = vld1q_f32(data_in_c + 8);
        float32x4_t vr3 = vld1q_f32(data_in_c + 12);
        uint32x4_t vm0 = vcltq_f32(vr0, vzero);    // vr0 <= vzero
        uint32x4_t vm1 = vcltq_f32(vr1, vzero);    // vr0 <= vzero
        uint32x4_t vm2 = vcltq_f32(vr2, vzero);    // vr0 <= vzero
        uint32x4_t vm3 = vcltq_f32(vr3, vzero);    // vr0 <= vzero
        float32x4_t vo0 = vmulq_f32(vr0, vslope);  // vr0 * vslope
        float32x4_t vo1 = vmulq_f32(vr1, vslope);  // vr0 * vslope
        float32x4_t vo2 = vmulq_f32(vr2, vslope);  // vr0 * vslope
        float32x4_t vo3 = vmulq_f32(vr3, vslope);  // vr0 * vslope
        float32x4_t vos0 = vbslq_f32(vm0, vo0, vr0);
        float32x4_t vos1 = vbslq_f32(vm1, vo1, vr1);
        float32x4_t vos2 = vbslq_f32(vm2, vo2, vr2);
        float32x4_t vos3 = vbslq_f32(vm3, vo3, vr3);
        vst1q_f32(data_out_c, vos0);
        vst1q_f32(data_out_c + 4, vos1);
        vst1q_f32(data_out_c + 8, vos2);
        vst1q_f32(data_out_c + 12, vos3);
        data_in_c += 16;
        data_out_c += 16;
      }
#else
      int cnt_loop = cnt;
      if (cnt_loop > 0) {
        asm volatile(
            "vld1.32  {d0-d3}, [%[ptr_in]]!  @ load input to q0, q1\n"
            "pld [%[ptr_in]]                 @ preload\n"
            "pld [%[ptr_in], #64]            @ preload\n"
            "pld [%[ptr_in], #128]           @ preload\n"
            "pld [%[ptr_in], #192]           @ preload\n"
            "1:                              @main loop\n"
            "vld1.32  {d4-d7}, [%[ptr_in]]!  @ load input to q2, q3\n"
            "vclt.f32 q8, q0, %q[vzero]      @vcle q0 <= vzero\n"
            "vclt.f32 q9, q1, %q[vzero]      @vcle q1 <= vzero\n"
            "vmul.f32 q10, q0, %q[vslope]    @vmul q0 * vslope\n"
            "vmul.f32 q11, q1, %q[vslope]    @vmul q1 * vslope\n"
            "vclt.f32 q12, q2, %q[vzero]     @vcle q2 <= vzero\n"
            "vclt.f32 q13, q3, %q[vzero]     @vcle q3 <= vzero\n"
            "vmul.f32 q14, q2, %q[vslope]    @vmul q2 * vslope\n"
            "vmul.f32 q15, q3, %q[vslope]    @vmul q3 * vslope\n"
            "vbif.32  q10, q0, q8            @vbit q10, q0, q8\n"
            "vbif.32  q11, q1, q9            @vbit q11, q1, q9\n"
            "vbif.32  q14, q2, q12           @vbit q14, q2, q12\n"
            "vbif.32  q15, q3, q13           @vbit q15, q3, q13\n"
            "subs %[cnt], #1                 @subs nn, 1\n"
            "vld1.32  {d0-d3}, [%[ptr_in]]!  @ load input to q0, q1\n"
            "vst1.f32 {d20-d23}, [%[dout]]!  @store data\n"
            "vst1.f32 {d28-d31}, [%[dout]]!  @store data\n"
            "bne 1b                          @bne nn\n"
            "sub %[ptr_in], #32              @ ptr-32\n"
            : [ptr_in] "+r"(data_in_c), [cnt] "+r"(cnt_loop),
              [dout] "+r"(data_out_c)
            : [vzero] "w"(vzero), [vslope] "w"(vslope)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
      }
#endif  // __aarch64__
      for (int i = remain; i > 0; i--) {
        *(data_out_c++) =
            data_in_c[0] > 0.f ? data_in_c[0] : data_in_c[0] * slope;
        data_in_c++;
      }
    }
  }
}

template <>
void act_sigmoid(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    float32x4_t exp_vec = vdupq_n_f32(0.0f);
    float32x4_t recip = vdupq_n_f32(0.0f);
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      exp_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread)));
      exp_vec = vaddq_f32(exp_vec, vdupq_n_f32(1.0f));
      recip = vrecpeq_f32(exp_vec);
      recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip);
      recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip);
      vst1q_f32(ptr_out_thread, recip);
      ptr_out_thread += 4;
      ptr_in_thread += 4;
    }
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] = 1.f / (1 + expf(-ptr_in_thread[0]));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = 1.f / (1 + expf(-ptr_in[0]));
    ptr_in++;
    ptr_out++;
  }
}
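An annotation on the reciprocal sequence in act_sigmoid above: vrecpeq_f32 yields only a coarse estimate of 1/x, and each vrecpsq_f32/vmulq_f32 pair applies one Newton–Raphson refinement r' = r * (2 - x * r); two steps give near-full float precision. A standalone scalar sketch of that convergence (hypothetical names, not part of the deleted file):

#include <cstdio>

// One Newton-Raphson step per iteration: r' = r * (2 - x * r), i.e. the
// scalar equivalent of vmulq_f32(vrecpsq_f32(x, r), r) in the NEON code.
static float recip_newton(float x, float r0, int steps) {
  float r = r0;  // coarse initial estimate of 1/x (vrecpeq_f32's role)
  for (int i = 0; i < steps; ++i) {
    r = r * (2.0f - x * r);
  }
  return r;
}

int main() {
  float x = 1.0f + 0.3678794f;           // 1 + exp(-1), as in sigmoid(1)
  float est = recip_newton(x, 0.7f, 2);  // two steps, matching the kernel
  std::printf("sigmoid(1) ~ %f (exact ~0.731059)\n", est);
  return 0;
}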
// tanh : (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <>
void act_tanh<float>(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    float32x4_t exp_plus_vec = vdupq_n_f32(0.0f);
    float32x4_t exp_minus_vec = vdupq_n_f32(0.0f);
    float32x4_t exp_sum_vec = vdupq_n_f32(0.0f);
    float32x4_t exp_diff_vec = vdupq_n_f32(0.0f);
    float32x4_t recip = vdupq_n_f32(0.0f);
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      exp_plus_vec = exp_ps(vld1q_f32(ptr_in_thread));
      exp_minus_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread)));
      exp_sum_vec = vaddq_f32(exp_plus_vec, exp_minus_vec);
      exp_diff_vec = vsubq_f32(exp_plus_vec, exp_minus_vec);
      recip = div_ps(exp_diff_vec, exp_sum_vec);
      vst1q_f32(ptr_out_thread, recip);
      ptr_out_thread += 4;
      ptr_in_thread += 4;
    }
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] = (expf(ptr_in_thread[0]) - expf(-ptr_in_thread[0])) /
                          (expf(ptr_in_thread[0]) + expf(-ptr_in_thread[0]));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = (expf(ptr_in[0]) - expf(-ptr_in[0])) /
                 (expf(ptr_in[0]) + expf(-ptr_in[0]));
    ptr_in++;
    ptr_out++;
  }
}

// swish: x /(1 + exp(-(b * x)))
template <>
void act_swish<float>(const float* din, float* dout, int size,
                      const float coef, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
  const float beta = coef;
  float32x4_t vbeta = vdupq_n_f32(beta);
  float32x4_t vone = vdupq_n_f32(1.f);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      float32x4_t va = vld1q_f32(ptr_in_thread);             // x
      float32x4_t vb = vnegq_f32(vld1q_f32(ptr_in_thread));  // -x
      float32x4_t vsum = vmulq_f32(vb, vbeta);
      vsum = exp_ps(vsum);
      float32x4_t vc = vaddq_f32(vone, vsum);
      float32x4_t vrst = div_ps(va, vc);
      vst1q_f32(ptr_out_thread, vrst);
      ptr_out_thread += 4;
      ptr_in_thread += 4;
    }
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] =
          ptr_in_thread[0] / (1.0 + expf(-ptr_in_thread[0] * beta));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = ptr_in[0] / (1.0 + expf(-ptr_in[0] * beta));
    ptr_in++;
    ptr_out++;
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
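The tanh and swish kernels above lean on exp_ps and div_ps from funcs.h; as a sanity check, a standalone scalar sketch of the same two formulas (assumed example code, not part of the deleted sources):

#include <cmath>
#include <cstdio>

// Scalar references for the formulas in the comments above.
static float tanh_ref(float x) {
  return (std::exp(x) - std::exp(-x)) / (std::exp(x) + std::exp(-x));
}
static float swish_ref(float x, float beta) {
  return x / (1.0f + std::exp(-beta * x));
}

int main() {
  std::printf("tanh(1)     = %f\n", tanh_ref(1.f));        // ~0.761594
  std::printf("swish(1, 1) = %f\n", swish_ref(1.f, 1.f));  // ~0.731059
  return 0;
}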
paddle/fluid/lite/arm/math/activation.h — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void act_relu(const T* din, T* dout, int size, int threads);

template <typename T>
void act_relu_neg(const T* din, T* dout, int size, const float negative_slope,
                  int threads);

template <typename T>
void act_clipped_relu(const T* din, T* dout, int size, const float coef,
                      int threads);

template <typename T>
void act_prelu(const T* din, T* dout, int outer_size, int channel_size,
               int inner_size, bool channel_shared, float* channel_slope,
               int threads);

template <typename T>
void act_sigmoid(const T* din, T* dout, int size, int threads);

template <typename T>
void act_tanh(const T* din, T* dout, int size, int threads);

template <typename T>
void act_swish(const T* din, T* dout, int size, const float coef, int threads);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/concat.cc — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/arm/math/concat.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "paddle/fluid/lite/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

void concat_func(const std::vector<lite::Tensor*>& input, const int axis,
                 lite::Tensor* output) {
  size_t num = input.size();
  int rows = 1;
  auto dim_0 = input[0]->dims();
  for (int i = 0; i < axis; ++i) {
    rows *= dim_0[i];
  }
  int out_rows = rows, out_cols = 0;
  std::vector<int64_t> input_cols(input.size());
  for (int i = 0; i < num; ++i) {
    int t_cols = input[i]->numel() / rows;
    out_cols += t_cols;
    input_cols[i] = t_cols;
  }
  // computation
  for (int k = 0; k < out_rows; ++k) {
    float* dst_ptr = output->mutable_data<float>() + k * out_cols;
    int col_idx = 0;
    for (int j = 0; j < num; ++j) {
      int col_len = input_cols[j];
      const float* src_prt = input[j]->data<float>() + k * col_len;
      std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len);
      col_idx += col_len;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
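concat_func above works by flattening: dimensions before `axis` collapse into `rows`, everything from `axis` onward collapses into a per-input column count, and each output row is the inputs' rows laid side by side. A standalone sketch of the same scheme on raw buffers (lite::Tensor replaced by plain arrays; names assumed):

#include <cstdio>
#include <cstring>
#include <vector>

// Concatenate per-row: each output row is the inputs' rows side by side.
static void concat_rows(const std::vector<const float*>& srcs,
                        const std::vector<int>& cols, int rows, float* dst) {
  int out_cols = 0;
  for (int c : cols) out_cols += c;
  for (int k = 0; k < rows; ++k) {
    int col_idx = 0;
    for (size_t j = 0; j < srcs.size(); ++j) {
      std::memcpy(dst + k * out_cols + col_idx, srcs[j] + k * cols[j],
                  sizeof(float) * cols[j]);
      col_idx += cols[j];
    }
  }
}

int main() {
  // Two 2x2 matrices concatenated along axis 1 -> one 2x4 matrix.
  float a[4] = {1, 2, 3, 4}, b[4] = {5, 6, 7, 8}, out[8];
  concat_rows({a, b}, {2, 2}, 2, out);
  for (float v : out) std::printf("%g ", v);  // 1 2 5 6 3 4 7 8
  return 0;
}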
paddle/fluid/lite/arm/math/concat.h — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/cp_logging.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

void concat_func(const std::vector<lite::Tensor*>& input, const int axis,
                 lite::Tensor* output);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/dropout.cc — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/arm/math/dropout.h"
#include "paddle/fluid/lite/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void dropout_down<float>(const float* din, float* dout, int num, float prob) {
  const float scale = 1.0f - prob;
  int cnt = num >> 4;
  int remain = num % 16;
  float32x4_t vscale = vdupq_n_f32(scale);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    float32x4_t vmul0 = vmulq_f32(din0, vscale);
    float32x4_t vmul1 = vmulq_f32(din1, vscale);
    float32x4_t vmul2 = vmulq_f32(din2, vscale);
    float32x4_t vmul3 = vmulq_f32(din3, vscale);
    vst1q_f32(dout_ptr, vmul0);
    vst1q_f32(dout_ptr + 4, vmul1);
    vst1q_f32(dout_ptr + 8, vmul2);
    vst1q_f32(dout_ptr + 12, vmul3);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr * scale;
      dout_ptr++;
      din_ptr++;
    }
  }
}

template <>
void dropout_up<float>(const float* din, float* dout, int num) {
  int cnt = num >> 4;
  int remain = num % 16;
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    vst1q_f32(dout_ptr, din0);
    vst1q_f32(dout_ptr + 4, din1);
    vst1q_f32(dout_ptr + 8, din2);
    vst1q_f32(dout_ptr + 12, din3);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr;
      dout_ptr++;
      din_ptr++;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
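These two kernels correspond to the two inference-time dropout conventions: dropout_down scales activations by (1 - prob), while dropout_up passes them through unchanged (the 1/(1 - prob) scaling having already been applied at training time); this appears to match Paddle's "downgrade_in_infer" and "upscale_in_train" dropout modes. A standalone scalar sketch of the scaled variant (names assumed):

#include <cstdio>

// Scalar reference for dropout_down: inference scales by keep probability.
static void dropout_down_ref(const float* din, float* dout, int num,
                             float prob) {
  const float scale = 1.0f - prob;
  for (int i = 0; i < num; ++i) dout[i] = din[i] * scale;
}

int main() {
  float in[4] = {1.f, 2.f, 3.f, 4.f}, out[4];
  dropout_down_ref(in, out, 4, 0.5f);
  for (float v : out) std::printf("%.1f ", v);  // 0.5 1.0 1.5 2.0
  return 0;
}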
paddle/fluid/lite/arm/math/dropout.h — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void dropout_down(const T* din, T* dout, int num, float prob);

template <typename T>
void dropout_up(const T* din, T* dout, int num);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/elementwise.cc — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/arm/math/elementwise.h"
#include "paddle/fluid/lite/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
                            int num) {
  int cnt = num >> 4;
  int remain = num % 16;
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* dinx_ptr = dinx + (i << 4);
    const float* diny_ptr = diny + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t dinx0 = vld1q_f32(dinx_ptr);
    float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
    float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
    float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
    float32x4_t diny0 = vld1q_f32(diny_ptr);
    float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
    float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
    float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
    dinx0 = vaddq_f32(dinx0, diny0);
    dinx1 = vaddq_f32(dinx1, diny1);
    dinx2 = vaddq_f32(dinx2, diny2);
    dinx3 = vaddq_f32(dinx3, diny3);
    vst1q_f32(dout_ptr, dinx0);
    vst1q_f32(dout_ptr + 4, dinx1);
    vst1q_f32(dout_ptr + 8, dinx2);
    vst1q_f32(dout_ptr + 12, dinx3);
  }
  if (remain > 0) {
    const float* dinx_ptr = dinx + (cnt << 4);
    const float* diny_ptr = diny + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *dinx_ptr + *diny_ptr;
      dout_ptr++;
      dinx_ptr++;
      diny_ptr++;
    }
  }
}

template <>
void elementwise_add_relu<float>(const float* dinx, const float* diny,
                                 float* dout, int num) {
  int cnt = num >> 4;
  int remain = num % 16;
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* dinx_ptr = dinx + (i << 4);
    const float* diny_ptr = diny + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t dinx0 = vld1q_f32(dinx_ptr);
    float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
    float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
    float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
    float32x4_t diny0 = vld1q_f32(diny_ptr);
    float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
    float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
    float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
    dinx0 = vaddq_f32(dinx0, diny0);
    dinx1 = vaddq_f32(dinx1, diny1);
    dinx2 = vaddq_f32(dinx2, diny2);
    dinx3 = vaddq_f32(dinx3, diny3);
    // relu
    dinx0 = vmaxq_f32(dinx0, vzero);
    dinx1 = vmaxq_f32(dinx1, vzero);
    dinx2 = vmaxq_f32(dinx2, vzero);
    dinx3 = vmaxq_f32(dinx3, vzero);
    vst1q_f32(dout_ptr, dinx0);
    vst1q_f32(dout_ptr + 4, dinx1);
    vst1q_f32(dout_ptr + 8, dinx2);
    vst1q_f32(dout_ptr + 12, dinx3);
  }
  if (remain > 0) {
    const float* dinx_ptr = dinx + (cnt << 4);
    const float* diny_ptr = diny + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      float tmp = *dinx_ptr + *diny_ptr;
      *dout_ptr = tmp > 0.f ? tmp : 0.f;
      dout_ptr++;
      dinx_ptr++;
      diny_ptr++;
    }
  }
}

template <>
void elementwise_add_broadcast<float>(const float* dinx, const float* diny,
                                      float* dout, int batch, int channels,
                                      int num) {
#pragma omp parallel for collapse(2)
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;
      const float* din_ptr = dinx + offset;
      const float diny_data = diny[j];
      float* dout_ptr = dout + offset;
      int cnt = num >> 4;
      int remain = num % 16;
      float32x4_t rb = vdupq_n_f32(diny_data);
      for (int k = 0; k < cnt; ++k) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        float32x4_t din2 = vld1q_f32(din_ptr + 8);
        float32x4_t din3 = vld1q_f32(din_ptr + 12);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        din2 = vaddq_f32(din2, rb);
        din3 = vaddq_f32(din3, rb);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        vst1q_f32(dout_ptr + 8, din2);
        vst1q_f32(dout_ptr + 12, din3);
        din_ptr += 16;
        dout_ptr += 16;
      }
      if (remain >= 8) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        din_ptr += 8;
        dout_ptr += 8;
        remain -= 8;
      }
      if (remain >= 4) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        din0 = vaddq_f32(din0, rb);
        vst1q_f32(dout_ptr, din0);
        din_ptr += 4;
        dout_ptr += 4;
        remain -= 4;
      }
      if (remain > 0) {
        for (int p = 0; p < remain; p++) {
          *dout_ptr = *din_ptr + diny_data;
          dout_ptr++;
          din_ptr++;
        }
      }
    }
  }
}

template <>
void elementwise_add_relu_broadcast<float>(const float* dinx,
                                           const float* diny, float* dout,
                                           int batch, int channels, int num) {
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for collapse(2)
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;
      const float* din_ptr = dinx + offset;
      const float diny_data = diny[j];
      float* dout_ptr = dout + offset;
      int cnt = num >> 4;
      int remain = num % 16;
      float32x4_t rb = vdupq_n_f32(diny_data);
      for (int k = 0; k < cnt; ++k) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        float32x4_t din2 = vld1q_f32(din_ptr + 8);
        float32x4_t din3 = vld1q_f32(din_ptr + 12);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        din2 = vaddq_f32(din2, rb);
        din3 = vaddq_f32(din3, rb);
        // relu
        din0 = vmaxq_f32(din0, vzero);
        din1 = vmaxq_f32(din1, vzero);
        din2 = vmaxq_f32(din2, vzero);
        din3 = vmaxq_f32(din3, vzero);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        vst1q_f32(dout_ptr + 8, din2);
        vst1q_f32(dout_ptr + 12, din3);
        din_ptr += 16;
        dout_ptr += 16;
      }
      if (remain >= 8) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        // relu
        din0 = vmaxq_f32(din0, vzero);
        din1 = vmaxq_f32(din1, vzero);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        din_ptr += 8;
        dout_ptr += 8;
        remain -= 8;
      }
      if (remain >= 4) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        din0 = vaddq_f32(din0, rb);
        // relu
        din0 = vmaxq_f32(din0, vzero);
        vst1q_f32(dout_ptr, din0);
        din_ptr += 4;
        dout_ptr += 4;
        remain -= 4;
      }
      if (remain > 0) {
        for (int p = 0; p < remain; p++) {
          float tmp = *din_ptr + diny_data;
          *dout_ptr = tmp > 0.f ? tmp : 0.f;
          dout_ptr++;
          din_ptr++;
        }
      }
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
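In the broadcast variants above, dinx is laid out as [batch, channels, num] and diny holds one scalar per channel, added to every element of that channel's inner block. A standalone scalar sketch of the same indexing (names assumed, not part of the deleted file):

#include <cstdio>

// Scalar reference for elementwise_add_broadcast's channel-wise bias add.
static void add_broadcast_ref(const float* dinx, const float* diny,
                              float* dout, int batch, int channels, int num) {
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;  // start of this channel block
      for (int k = 0; k < num; ++k) {
        dout[offset + k] = dinx[offset + k] + diny[j];
      }
    }
  }
}

int main() {
  float x[8] = {0, 1, 2, 3, 4, 5, 6, 7};  // batch=1, channels=2, num=4
  float bias[2] = {10.f, 20.f};
  float y[8];
  add_broadcast_ref(x, bias, y, 1, 2, 4);
  for (float v : y) std::printf("%g ", v);  // 10 11 12 13 24 25 26 27
  return 0;
}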
paddle/fluid/lite/arm/math/elementwise.h — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);

template <typename T>
void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num);

template <typename T>
void elementwise_add_broadcast(const T* dinx, const T* diny, T* dout,
                               int batch, int channels, int num);

template <typename T>
void elementwise_add_relu_broadcast(const T* dinx, const T* diny, T* dout,
                                    int batch, int channels, int num);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/pooling.cc — deleted (mode 100644 → 0):
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/pooling.h"
#include <algorithm>
#include <limits>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace
paddle
{
namespace
lite
{
namespace
arm
{
namespace
math
{
void
pooling_basic
(
const
float
*
din
,
float
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
bool
global_pooling
,
bool
exclusive
,
bool
adaptive
,
bool
ceil_mode
,
bool
use_quantizer
,
const
std
::
string
&
pooling_type
)
{
// no need to pad input tensor, border is zero pad inside this function
int
kernel_h
=
ksize
[
0
];
int
kernel_w
=
ksize
[
1
];
int
stride_h
=
strides
[
0
];
int
stride_w
=
strides
[
1
];
int
pad_h
=
paddings
[
0
];
int
pad_w
=
paddings
[
1
];
int
size_channel_in
=
win
*
hin
;
int
size_channel_out
=
wout
*
hout
;
if
(
global_pooling
)
{
if
(
pooling_type
==
"max"
)
{
// Pooling_max
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
++
c
)
{
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
// in address
float
tmp1
=
din_ch
[
0
];
for
(
int
i
=
0
;
i
<
size_channel_in
;
++
i
)
{
float
tmp2
=
din_ch
[
i
];
tmp1
=
tmp1
>
tmp2
?
tmp1
:
tmp2
;
}
dout_batch
[
c
]
=
tmp1
;
}
}
}
else
if
(
pooling_type
==
"avg"
)
{
// Pooling_average_include_padding
// Pooling_average_exclude_padding
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
++
c
)
{
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
// in address
float
sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
size_channel_in
;
++
i
)
{
sum
+=
din_ch
[
i
];
}
dout_batch
[
c
]
=
sum
/
size_channel_in
;
}
}
}
else
{
LOG
(
FATAL
)
<<
"unsupported pooling type: "
<<
pooling_type
;
}
}
else
{
if
(
pooling_type
==
"max"
)
{
// Pooling_max
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_ch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
float
*
dout_row
=
dout_ch
+
c
*
size_channel_out
;
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
for
(
int
i
=
0
;
i
<
hout
;
i
++
)
{
for
(
int
j
=
0
;
j
<
wout
;
j
++
)
{
int
hstart
=
i
*
stride_h
-
pad_h
;
int
wstart
=
j
*
stride_w
-
pad_w
;
int
hend
=
std
::
min
(
hstart
+
kernel_h
,
hin
+
pad_h
);
int
wend
=
std
::
min
(
wstart
+
kernel_w
,
win
+
pad_w
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
hend
=
std
::
min
(
hend
,
hin
);
wend
=
std
::
min
(
wend
,
win
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
if
(
pool_size
==
0
)
continue
;
float
tmp1
=
din_ch
[
hstart
*
win
+
wstart
];
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
float
tmp2
=
din_ch
[
h
*
win
+
w
];
tmp1
=
tmp1
>
tmp2
?
tmp1
:
tmp2
;
}
}
dout_row
[
j
]
=
tmp1
;
}
dout_row
+=
wout
;
}
}
}
}
else
if
(
pooling_type
==
"avg"
)
{
if
(
exclusive
)
{
// Pooling_average_exclude_padding
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_ch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
float
*
dout_row
=
dout_ch
+
c
*
size_channel_out
;
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
for
(
int
i
=
0
;
i
<
hout
;
i
++
)
{
for
(
int
j
=
0
;
j
<
wout
;
j
++
)
{
int
hstart
=
i
*
stride_h
-
pad_h
;
int
wstart
=
j
*
stride_w
-
pad_w
;
int
hend
=
std
::
min
(
hstart
+
kernel_h
,
hin
+
pad_h
);
int
wend
=
std
::
min
(
wstart
+
kernel_w
,
win
+
pad_w
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
hend
=
std
::
min
(
hend
,
hin
);
wend
=
std
::
min
(
wend
,
win
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
if
(
pool_size
==
0
)
continue
;
float
sum
=
0.
f
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
sum
+=
din_ch
[
h
*
win
+
w
];
}
}
dout_row
[
j
]
=
sum
/
pool_size
;
}
dout_row
+=
wout
;
}
}
}
}
else
{
// Pooling_average_include_padding
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_ch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
float
*
dout_row
=
dout_ch
+
c
*
size_channel_out
;
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
for
(
int
i
=
0
;
i
<
hout
;
i
++
)
{
for
(
int
j
=
0
;
j
<
wout
;
j
++
)
{
int
hstart
=
i
*
stride_h
-
pad_h
;
int
wstart
=
j
*
stride_w
-
pad_w
;
int
hend
=
std
::
min
(
hstart
+
kernel_h
,
hin
+
pad_h
);
int
wend
=
std
::
min
(
wstart
+
kernel_w
,
win
+
pad_w
);
hstart
=
std
::
max
(
hstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
hend
=
std
::
min
(
hend
,
hin
);
wend
=
std
::
min
(
wend
,
win
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
if
(
pool_size
==
0
)
continue
;
float
sum
=
0.
f
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
sum
+=
din_ch
[
h
*
win
+
w
];
}
}
dout_row
[
j
]
=
sum
/
(
kernel_w
*
kernel_h
);
}
dout_row
+=
wout
;
}
}
}
}
}
else
{
LOG
(
FATAL
)
<<
"unsupported pooling type: "
<<
pooling_type
;
}
}
}
void
pooling_global_max
(
const
float
*
din
,
float
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
)
{
int
size_channel_in
=
win
*
hin
;
int
cnt
=
size_channel_in
/
8
;
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
++
c
)
{
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
int
i
=
0
;
float
minval
=
std
::
numeric_limits
<
float
>::
lowest
();
float32x4_t
vmax
=
vdupq_n_f32
(
minval
);
#ifdef __aarch64__
for
(;
i
<
cnt
;
i
++
)
{
float32x4_t
vdin1
=
vld1q_f32
(
din_ch
);
vmax
=
vmaxq_f32
(
vdin1
,
vmax
);
float32x4_t
vdin2
=
vld1q_f32
(
din_ch
+
4
);
vmax
=
vmaxq_f32
(
vmax
,
vdin2
);
din_ch
+=
8
;
}
#else
int
cnt_num
=
cnt
;
if
(
cnt_num
>
0
)
{
asm
volatile
(
"max_loop: @main loop
\n
"
"vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch
\n
"
"vmax.f32 %q[vmax], %q[vmax], q0 @max vmax,vmax,din_ch
\n
"
"vld1.f32 {d2-d3}, [%[din_ch]]! @load 2nd 4 data
\n
"
"vmax.f32 %q[vmax], %q[vmax], q1 @compare 2nd 4 datas
\n
"
"subs %[cnt_num], #1 @cnt_num--
\n
"
"bne max_loop @bne cnt_num
\n
"
:
[
din_ch
]
"+r"
(
din_ch
),
[
cnt_num
]
"+r"
(
cnt_num
),
[
vmax
]
"+w"
(
vmax
)
:
:
"cc"
,
"memory"
,
"q0"
,
"q1"
);
}
#endif // __aarch64__
float32x2_t
vmax_tmp
=
vmax_f32
(
vget_low_f32
(
vmax
),
vget_high_f32
(
vmax
));
float
tmp1
=
vget_lane_f32
(
vmax_tmp
,
0
);
float
tmp2
=
vget_lane_f32
(
vmax_tmp
,
1
);
float
max_tmp
=
tmp1
>
tmp2
?
tmp1
:
tmp2
;
for
(
i
=
cnt
*
8
;
i
<
size_channel_in
;
++
i
)
{
/* code */
max_tmp
=
max_tmp
>
din_ch
[
0
]
?
max_tmp
:
din_ch
[
0
];
din_ch
++
;
}
dout_batch
[
c
]
=
max_tmp
;
}
}
}
void
pooling_global_avg
(
const
float
*
din
,
float
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
)
{
int
size_channel_in
=
win
*
hin
;
int
cnt
=
size_channel_in
/
4
;
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
// in address
int
i
=
0
;
float32x4_t
vsum
=
vdupq_n_f32
(
0.0
f
);
#ifdef __aarch64__
for
(;
i
<
cnt
;
i
++
)
{
vsum
=
vaddq_f32
(
vld1q_f32
(
din_ch
),
vsum
);
din_ch
+=
4
;
}
#else
int
cnt_num
=
cnt
;
if
(
cnt_num
>
0
)
{
asm
volatile
(
"add_loop: @main loop
\n
"
"vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch
\n
"
"vadd.f32 %q[vsum], %q[vsum], q0 @add vmax,vmax, din_ch
\n
"
"subs %[cnt_num], #1 @cnt_num--
\n
"
"bne add_loop @bne num
\n
"
:
[
din_ch
]
"+r"
(
din_ch
),
[
cnt_num
]
"+r"
(
cnt_num
),
[
vsum
]
"+w"
(
vsum
)
:
:
"cc"
,
"memory"
,
"q0"
);
}
#endif // __aarch64__
float32x2_t
vsum_tmp
=
vadd_f32
(
vget_low_f32
(
vsum
),
vget_high_f32
(
vsum
));
float
sum
=
vget_lane_f32
(
vsum_tmp
,
0
)
+
vget_lane_f32
(
vsum_tmp
,
1
);
for
(
i
=
cnt
*
4
;
i
<
size_channel_in
;
i
++
)
{
sum
+=
din_ch
[
0
];
din_ch
++
;
}
dout_batch
[
c
]
=
sum
/
size_channel_in
;
}
}
}
void
pooling2x2s2_max
(
const
float
*
din
,
float
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
)
{
int
kernel
=
2
;
int
stride
=
2
;
int
padding
=
0
;
int
size_channel_out
=
wout
*
hout
;
int
size_channel_in
=
win
*
hin
;
int
w_needed
=
(
wout
<<
1
);
int
h_needed
=
(
hout
<<
1
);
int
w_limit
=
w_needed
>
win
?
win
:
w_needed
;
int
h_limit
=
h_needed
>
hin
?
hin
:
h_needed
;
int
w_even
=
(
w_limit
>>
1
)
<<
1
;
int
h_even
=
(
h_limit
>>
1
)
<<
1
;
int
w_unroll_size
=
(
w_even
>>
3
)
<<
3
;
// int w_unroll_remain = w_even - w_unroll_size;
int
w_in_2
=
win
<<
1
;
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
float
*
dout_ch
=
dout_batch
+
c
*
size_channel_out
;
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
const
float
*
r0
=
din_ch
;
const
float
*
r1
=
r0
+
win
;
int
h
=
0
;
for
(;
h
<
h_even
;
h
+=
2
)
{
int
w
=
0
;
#ifdef __aarch64__
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
dr00
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
dr01
=
vld1q_f32
(
&
r0
[
w
+
4
]);
float32x4_t
dr10
=
vld1q_f32
(
&
r1
[
w
]);
float32x4_t
dr11
=
vld1q_f32
(
&
r1
[
w
+
4
]);
float32x4_t
dmax1
=
vmaxq_f32
(
dr00
,
dr10
);
float32x4_t
dmax2
=
vmaxq_f32
(
dr01
,
dr11
);
#ifdef __aarch64__
float32x4_t
dmax
=
vpmaxq_f32
(
dmax1
,
dmax2
);
#else
float32x2_t
dmaxl
=
vpmax_f32
(
vget_low_f32
(
dmax1
),
vget_high_f32
(
dmax1
));
float32x2_t
dmaxh
=
vpmax_f32
(
vget_low_f32
(
dmax2
),
vget_high_f32
(
dmax2
));
float32x4_t
dmax
=
vcombine_f32
(
dmaxl
,
dmaxh
);
#endif
vst1q_f32
(
&
dout_ch
[
w
>>
1
],
dmax
);
}
#else
float
*
dr_out
=
dout_ch
;
const
float
*
dr0
=
r0
;
const
float
*
dr1
=
r1
;
int
cnt_num
=
w_unroll_size
>>
3
;
if
(
cnt_num
>
0
)
{
asm
volatile
(
"s2_max_loop: @main loop
\n
"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0
\n
"
"vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1
\n
"
"vmax.f32 q0, q0, q2 @max q0,q0,q2
\n
"
"vmax.f32 q1, q1, q3 @max q1,q1,q2
\n
"
"vpmax.f32 d4, d0, d1 @max d4,d0,d1
\n
"
"vpmax.f32 d5, d2, d3 @max d5,d2,d3
\n
"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out
\n
"
"subs %[cnt_num], #1 @cnt_num--
\n
"
"bne s2_max_loop @bne cnt_num
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr1
]
"+r"
(
dr1
),
[
dr_out
]
"+r"
(
dr_out
),
[
cnt_num
]
"+r"
(
cnt_num
)
:
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
);
}
w
=
w_unroll_size
;
#endif // __aarch64__
for
(;
w
<
w_even
;
w
+=
2
)
{
dout_ch
[
w
>>
1
]
=
std
::
max
(
std
::
max
(
r0
[
w
],
r0
[
w
+
1
]),
std
::
max
(
r1
[
w
],
r1
[
w
+
1
]));
}
for
(;
w
<
w_limit
;
++
w
)
{
// run 0 or 1 time
dout_ch
[
w
>>
1
]
=
std
::
max
(
r0
[
w
],
r1
[
w
]);
}
r0
+=
w_in_2
;
// << 1;
r1
+=
w_in_2
;
// << 1;
dout_ch
+=
wout
;
}
// process remain row (odd, last row)
for
(;
h
<
h_limit
;
h
++
)
{
// run 0 or 1 time
int
w
=
0
;
#ifdef __aarch64__
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
dr00
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
dr01
=
vld1q_f32
(
&
r0
[
w
+
4
]);
#ifdef __aarch64__
float32x4_t
dmax
=
vpmaxq_f32
(
dr00
,
dr01
);
#else
float32x2_t
dmaxl
=
vpmax_f32
(
vget_low_f32
(
dr00
),
vget_high_f32
(
dr00
));
float32x2_t
dmaxh
=
vpmax_f32
(
vget_low_f32
(
dr01
),
vget_high_f32
(
dr01
));
float32x4_t
dmax
=
vcombine_f32
(
dmaxl
,
dmaxh
);
#endif
vst1q_f32
(
&
dout_ch
[
w
>>
1
],
dmax
);
}
#else
float
*
dr_out
=
dout_ch
;
const
float
*
dr0
=
r0
;
int
cnt_num
=
w_unroll_size
>>
3
;
if
(
cnt_num
>
0
)
{
asm
volatile
(
"s2_max_loop1: @main loop
\n
"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0
\n
"
"vpmax.f32 d4, d0, d1 @max d4,d0,d1
\n
"
"vpmax.f32 d5, d2, d3 @max d5,d2,d3
\n
"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out
\n
"
"subs %[cnt_num], #1 @cnt_num--
\n
"
"bne s2_max_loop1 @bne cnt_num
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr_out
]
"+r"
(
dr_out
),
[
cnt_num
]
"+r"
(
cnt_num
)
:
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
);
}
w
=
w_unroll_size
;
#endif // __aarch64__
for
(;
w
<
w_even
;
w
+=
2
)
{
dout_ch
[
w
>>
1
]
=
std
::
max
(
r0
[
w
],
r0
[
w
+
1
]);
}
for
(;
w
<
w_limit
;
++
w
)
{
// run 0 or 1 time
dout_ch
[
w
>>
1
]
=
r0
[
w
];
}
}
}
}
}
void
pooling2x2s2_avg
(
const
float
*
din
,
float
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
,
bool
exclusive
)
{
int
kernel
=
2
;
int
stride
=
2
;
int
padding
=
0
;
int
size_channel_out
=
wout
*
hout
;
int
size_channel_in
=
win
*
hin
;
int
w_needed
=
(
wout
<<
1
);
int
h_needed
=
(
hout
<<
1
);
int
w_limit
=
w_needed
>
win
?
win
:
w_needed
;
int
h_limit
=
h_needed
>
hin
?
hin
:
h_needed
;
int
w_even
=
(
w_limit
>>
1
)
<<
1
;
int
h_even
=
(
h_limit
>>
1
)
<<
1
;
int
w_unroll_size
=
(
w_even
>>
3
)
<<
3
;
// int w_unroll_remain = w_even - w_unroll_size;
int
w_in_2
=
win
<<
1
;
const
float
coef
=
1.
f
/
4.
f
;
const
float
coef_1
=
exclusive
?
1.
f
:
coef
;
const
float
coef_2
=
exclusive
?
1.
f
/
2.
f
:
coef
;
float32x4_t
vcoef
=
vdupq_n_f32
(
coef
);
float32x4_t
vcoef_1
=
vdupq_n_f32
(
coef_1
);
float32x4_t
vcoef_2
=
vdupq_n_f32
(
coef_2
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
float
*
dout_ch
=
dout_batch
+
c
*
size_channel_out
;
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
const
float
*
r0
=
din_ch
;
const
float
*
r1
=
r0
+
win
;
int
h
=
0
;
for
(;
h
<
h_even
;
h
+=
2
)
{
int
w
=
0
;
#ifdef __aarch64__
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
dr00
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
dr01
=
vld1q_f32
(
&
r0
[
w
+
4
]);
float32x4_t
dr10
=
vld1q_f32
(
&
r1
[
w
]);
float32x4_t
dr11
=
vld1q_f32
(
&
r1
[
w
+
4
]);
float32x4_t
dsum1
=
vaddq_f32
(
dr00
,
dr10
);
float32x4_t
dsum2
=
vaddq_f32
(
dr01
,
dr11
);
#ifdef __aarch64__
float32x4_t
dsum
=
vpaddq_f32
(
dsum1
,
dsum2
);
#else
float32x2_t
dsuml
=
vpadd_f32
(
vget_low_f32
(
dsum1
),
vget_high_f32
(
dsum1
));
float32x2_t
dsumh
=
vpadd_f32
(
vget_low_f32
(
dsum2
),
vget_high_f32
(
dsum2
));
float32x4_t
dsum
=
vcombine_f32
(
dsuml
,
dsumh
);
#endif
float32x4_t
res
=
vmulq_f32
(
dsum
,
vcoef
);
vst1q_f32
(
&
dout_ch
[
w
>>
1
],
res
);
}
#else
float
*
dr_out
=
dout_ch
;
const
float
*
dr0
=
r0
;
const
float
*
dr1
=
r1
;
int
cnt_num
=
w_unroll_size
>>
3
;
if
(
cnt_num
>
0
)
{
asm
volatile
(
"1: @main loop
\n
"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0
\n
"
"vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1
\n
"
"vadd.f32 q0, q0, q2 @add q0,q0,q2
\n
"
"vadd.f32 q1, q1, q3 @add q1,q1,q2
\n
"
"vpadd.f32 d4, d0, d1 @add d4,d0,d1
\n
"
"vpadd.f32 d5, d2, d3 @add d5,d2,d3
\n
"
"vmul.f32 q2, q2, %q[vcoef] @mul q2,q2,vcoef
\n
"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out
\n
"
"subs %[cnt_num], #1 @cnt_num--
\n
"
"bne 1b @bne cnt_num
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr1
]
"+r"
(
dr1
),
[
dr_out
]
"+r"
(
dr_out
),
[
vcoef
]
"+w"
(
vcoef
),
[
cnt_num
]
"+r"
(
cnt_num
)
:
"r"
(
dr0
),
"r"
(
dr1
),
"r"
(
dr_out
),
"r"
(
cnt_num
),
"w"
(
vcoef
)
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
);
}
w
=
w_unroll_size
;
#endif // __aarch64__
for
(;
w
<
w_even
;
w
+=
2
)
{
dout_ch
[
w
>>
1
]
=
(
r0
[
w
]
+
r0
[
w
+
1
]
+
r1
[
w
]
+
r1
[
w
+
1
])
*
coef
;
}
for
(;
w
<
w_limit
;
++
w
)
{
// run 0 or 1 time
dout_ch
[
w
>>
1
]
=
(
r0
[
w
]
+
r1
[
w
])
*
coef_2
;
}
r0
+=
w_in_2
;
// << 1;
r1
+=
w_in_2
;
// << 1;
dout_ch
+=
wout
;
}
// process remain row (odd, last row)
for
(;
h
<
h_limit
;
h
++
)
{
// run 0 or 1 time
int
w
=
0
;
#ifdef __aarch64__
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
dr00
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
dr01
=
vld1q_f32
(
&
r0
[
w
+
4
]);
#ifdef __aarch64__
float32x4_t
dsum
=
vpaddq_f32
(
dr00
,
dr01
);
#else
float32x2_t
dsuml
=
vpadd_f32
(
vget_low_f32
(
dr00
),
vget_high_f32
(
dr00
));
float32x2_t
dsumh
=
vpadd_f32
(
vget_low_f32
(
dr01
),
vget_high_f32
(
dr01
));
float32x4_t
dsum
=
vcombine_f32
(
dsuml
,
dsumh
);
#endif
float32x4_t
res
=
vmulq_f32
(
dsum
,
vcoef_2
);
vst1q_f32
(
&
dout_ch
[
w
>>
1
],
res
);
}
#else
float
*
dr_out
=
dout_ch
;
const
float
*
dr0
=
r0
;
int
cnt_num
=
w_unroll_size
>>
3
;
if
(
cnt_num
>
0
)
{
asm
volatile
(
"1: @main loop
\n
"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0
\n
"
"vpadd.f32 d4, d0, d1 @add d4,d0,d1
\n
"
"vpadd.f32 d5, d2, d3 @add d5,d2,d3
\n
"
"vmul.f32 q2, q2, %q[vcoef_2] @mul q2,q2,vcoef_2
\n
"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out
\n
"
"subs %[cnt_num], #1 @cnt_num--
\n
"
"bne 1b @bne cnt_num
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr_out
]
"+r"
(
dr_out
),
[
vcoef_2
]
"+w"
(
vcoef_2
),
[
cnt_num
]
"+r"
(
cnt_num
)
:
"r"
(
dr0
),
"r"
(
dr_out
),
"r"
(
cnt_num
),
"w"
(
vcoef_2
)
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
);
}
w
=
w_unroll_size
;
#endif // __aarch64__
for
(;
w
<
w_even
;
w
+=
2
)
{
dout_ch
[
w
>>
1
]
=
(
r0
[
w
]
+
r0
[
w
+
1
])
*
coef_2
;
}
for
(;
w
<
w_limit
;
++
w
)
{
// run 0 or 1 time
dout_ch
[
w
>>
1
]
=
r0
[
w
]
*
coef_1
;
}
}
}
}
}
void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int kernel = 3;
  int stride = 1;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_unroll_size = ((win - 2) >> 2) << 2;
  int w_unroll_remain = win - 2 - w_unroll_size;
  const float minval = std::numeric_limits<float>::lowest();
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 2;  // w_unroll_size / 4
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 0;
      int cnt = 1;
      // left
      dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#ifdef __aarch64__
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
        float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
        float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
        float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
        float32x2_t vmax_12_34 =
            vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
        float32x2_t vmax_23_45 =
            vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
        float32x2_t vmax_34_56 =
            vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
        float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
        float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
        float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
        vst1q_f32(&dout_ch[cnt], vmax);
        cnt += 4;
      }
#else
      dr_out = dr_out + 1;
      if (cnt_num > 0) {
        asm volatile(
            "1:                           @main loop\n"
            "vld1.f32 {d0-d1}, [%[dr0]]!  @load d0-d5,dr0\n"
            "vld1.f32 {d4-d5}, [%[dr1]]!  @load d4-d7,dr1\n"
            "vld1.f32 {d2}, [%[dr0]]!     @load d0-d5,dr0\n"
            "vld1.f32 {d6}, [%[dr1]]!     @load d4-d7,dr1\n"
            "vmax.f32 q5, q0, q2          @max r0_1234,r1_1234\n"
            "vmax.f32 d12, d2, d6         @max r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6             @mov s7,s6\n"
            "vext.f32 q0, q5, q6, #1      @vext max_2345\n"
            "vext.f32 q2, q5, q6, #2      @vext max_3456\n"
            "vpmax.f32 d2, d10, d11       @pmax d4,max_1234,max_1234\n"
            "vpmax.f32 d3, d0, d1         @pmax d4,max_2345,max_2345\n"
            "vpmax.f32 d6, d4, d5         @pmax d6,max_3456,max_3456\n"
            "vmax.f32 d8, d2, d3          @max d2,vmax_12_34,vmax_23_45\n"
            "vmax.f32 d9, d3, d6          @max d2,vmax_23_45,vmax_34_56\n"
            "sub %[dr0], #8               @sub w,8\n"
            "sub %[dr1], #8               @sub w,8\n"
            // swap
            "vmov.f32 s0, s17             @mov\n"
            "vmov.f32 s17, s18            @mov\n"
            "vmov.f32 s18, s0             @mov\n"
            "subs %[cnt_num], #1          @subs cnt_num,#1\n"
            "vst1.f32 d8, [%[dr_out]]!    @vst1 d0,dr_out\n"
            "vst1.f32 d9, [%[dr_out]]!    @vst1 d0,dr_out\n"
            "bne 1b                       @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_max = std::max(r0[j + w], r1[j + w]);
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
        dout_ch[j + w + 1] = tmp_max;
      }
      // right
      float tmp = std::max(r0[win - 2], r1[win - 2]);
      tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
      dout_ch[wout - 1] = tmp;
      // r0 = r1;
      // r1 = r0 + w_in;
      // r2 = r1 + w_in;
      dout_ch += wout;
      int h = 0;
      for (; h < hin - 2; h += 1) {
        // deal with left pad
        float maxr0 = std::max(r0[0], r0[1]);
        float maxr1 = std::max(r1[0], r1[1]);
        float maxr2 = std::max(r2[0], r2[1]);
        dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
        w = 0;
        cnt = 1;
        for (; w < w_unroll_size; w += 4) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_34_56 =
              vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
          float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
          vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
          vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
          vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
          vst1q_f32(&dout_ch[cnt], vmax);
          cnt += 4;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = r0;
        dr1 = r1;
        dr2 = r2;
        cnt_num = w_unroll_size >> 2;
        if (cnt_num > 0) {
          asm volatile(
              "1:                           @main loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!  @load d0-d5,dr0\n"
              "vld1.f32 {d4-d5}, [%[dr1]]!  @load d4-d7,dr1\n"
              "vld1.f32 {d8-d9}, [%[dr2]]!  @load d4-d7,dr1\n"
              "vld1.f32 {d2}, [%[dr0]]!     @load d0-d5,dr0\n"
              "vld1.f32 {d6}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vld1.f32 {d10}, [%[dr2]]!    @load d4-d7, dr1\n"
              "vmax.f32 q7, q0, q2          @max r0_1234,r1_1234\n"
              "vmax.f32 d16, d2, d6         @max r0_5678,r1_5678\n"
              "vmax.f32 q3, q7, q4          @max r0_1234,r1_1234\n"
              "vmax.f32 d12, d16, d10       @max r0_5678,r1_5678\n"
              //"vmov.f32 s7,s6             @mov s7,s6\n"
              "vext.f32 q0, q3, q6, #1      @vext max_2345\n"
              "vext.f32 q2, q3, q6, #2      @vext max_3456\n"
              "vpmax.f32 d2, d6, d7         @pmax d4,max_1234,max_1234\n"
              "vpmax.f32 d3, d0, d1         @pmax d4,max_2345,max_2345\n"
              "vpmax.f32 d6, d4, d5         @pmax d6,max_3456,max_3456\n"
              "vmax.f32 d8, d2, d3          @max d2,vmax_12_34,vmax_23_45\n"
              "vmax.f32 d9, d3, d6          @max d2,vmax_23_45,vmax_34_56\n"
              "sub %[dr0], #8               @sub w,8\n"
              "sub %[dr1], #8               @sub w,8\n"
              "sub %[dr2], #8               @sub w,8\n"
              // swap
              "vmov.f32 s0, s17             @mov\n"
              "vmov.f32 s17, s18            @mov\n"
              "vmov.f32 s18, s0             @mov\n"
              "subs %[cnt_num], #1          @subs cnt_num,#1\n"
              "vst1.f32 d8, [%[dr_out]]!    @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!    @vst1 d0,dr_out\n"
              "bne 1b                       @bne s1_max_loop\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8");
        }
#endif
        // remain
        w = w_unroll_size;
        for (int j = 0; j < w_unroll_remain; j++) {
          float tmp_max = std::max(r0[j + w], r1[j + w]);
          tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
          tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
          tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1]));
          tmp_max = std::max(tmp_max, r2[j + w + 2]);
          dout_ch[j + w + 1] = tmp_max;
        }
        // right
        tmp = std::max(r0[win - 2], r1[win - 2]);
        tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
        tmp = std::max(tmp, std::max(r2[win - 2], r2[win - 1]));
        dout_ch[wout - 1] = tmp;
        r0 = r1;
        r1 = r2;
        r2 = r1 + win;
        dout_ch += wout;
      }
      // the last two line
      float maxr0 = std::max(r0[0], r0[1]);
      float maxr1 = std::max(r1[0], r1[1]);
      dout_ch[0] = std::max(maxr0, maxr1);
#ifdef __aarch64__
      w = 0;
      cnt = 1;
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
        float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
        float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
        float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
        float32x2_t vmax_12_34 =
            vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
        float32x2_t vmax_23_45 =
            vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
        float32x2_t vmax_34_56 =
            vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
        float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
        float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
        float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
        vst1q_f32(&dout_ch[cnt], vmax);
        cnt += 4;
      }
#else
      dr_out = dout_ch + 1;
      dr0 = r0;
      dr1 = r1;
      cnt_num = w_unroll_size >> 2;
      if (cnt_num > 0) {
        asm volatile(
            "1:                           @main loop\n"
            "vld1.f32 {d0-d1}, [%[dr0]]!  @load d0-d5,dr0\n"
            "vld1.f32 {d4-d5}, [%[dr1]]!  @load d4-d7,dr1\n"
            "vld1.f32 {d2}, [%[dr0]]!     @load d0-d5,dr0\n"
            "vld1.f32 {d6}, [%[dr1]]!     @load d4-d7,dr1\n"
            "vmax.f32 q5, q0, q2          @max r0_1234,r1_1234\n"
            "vmax.f32 d12, d2, d6         @max r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6             @mov s7,s6\n"
            "vext.f32 q0, q5, q6, #1      @vext max_2345\n"
            "vext.f32 q2, q5, q6, #2      @vext max_3456\n"
            "vpmax.f32 d2, d10, d11       @pmax d4,max_1234,max_1234\n"
            "vpmax.f32 d3, d0, d1         @pmax d4,max_2345,max_2345\n"
            "vpmax.f32 d6, d4, d5         @pmax d6,max_3456,max_3456\n"
            "vmax.f32 d8, d2, d3          @max d2,vmax_12_34,vmax_23_45\n"
            "vmax.f32 d9, d3, d6          @max d2,vmax_23_45,vmax_34_56\n"
            "sub %[dr0], #8               @sub w,8\n"
            "sub %[dr1], #8               @sub w,8\n"
            // swap
            "vmov.f32 s0, s17             @mov\n"
            "vmov.f32 s17, s18            @mov\n"
            "vmov.f32 s18, s0             @mov\n"
            "subs %[cnt_num], #1          @subs cnt_num,#1\n"
            "vst1.f32 d8, [%[dr_out]]!    @vst1 d0,dr_out\n"
            "vst1.f32 d9, [%[dr_out]]!    @vst1 d0,dr_out\n"
            "bne 1b                       @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_max = std::max(r0[j + w], r1[j + w]);
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
        dout_ch[j + w + 1] = tmp_max;
      }
      tmp = std::max(r0[win - 2], r1[win - 2]);
      tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
      dout_ch[wout - 1] = tmp;
    }
  }
}
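
// 3x3 average pooling, stride 1, padding 1. With `exclusive` set, border
// windows divide by the number of valid elements (coef_2/coef_4/coef_6);
// otherwise every window, including padded ones, divides by 9.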
void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive) {
  int kernel = 3;
  int stride = 1;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_unroll_size = ((win - 2) >> 2) << 2;
  int w_unroll_remain = win - 2 - w_unroll_size;
  const float coef = 1.f / 9.f;
  const float coef_2 = exclusive ? 1.f / 2.f : coef;
  const float coef_4 = exclusive ? 1.f / 4.f : coef;
  const float coef_6 = exclusive ? 1.f / 6.f : coef;
  float32x4_t vcoef = vdupq_n_f32(coef);
  float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
  float32x4_t vcoef_4 = vdupq_n_f32(coef_4);
  float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 2;  // w_unroll_size / 4
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 0;
      int cnt = 1;
      // left
      dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
// first row with zero pad
#ifdef __aarch64__
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
        float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
        float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
        float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
        float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
        vsum = vaddq_f32(vsum, vsum_3456);
        vsum = vmulq_f32(vsum, vcoef_6);
        vst1q_f32(&dout_ch[cnt], vsum);
        cnt += 4;
      }
#else
      dr_out = dr_out + 1;
      if (cnt_num > 0) {
        asm volatile(
            "1:                            @main loop\n"
            "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d5,dr0\n"
            "vld1.f32 {d4-d5}, [%[dr1]]!   @load d4-d7,dr1\n"
            "vld1.f32 {d2}, [%[dr0]]!      @load d0-d5,dr0\n"
            "vld1.f32 {d6}, [%[dr1]]!      @load d4-d7,dr1\n"
            "vadd.f32 q5, q0, q2           @add r0_1234,r1_1234\n"
            "vadd.f32 d12, d2, d6          @add r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6              @mov s7,s6\n"
            "vext.f32 q0, q5, q6, #1       @vext max_2345\n"
            "vext.f32 q2, q5, q6, #2       @vext max_3456\n"
            "vadd.f32 q1, q5, q0           @add 1234+2345\n"
            "vadd.f32 q1, q1, q2           @add + 3456\n"
            "vmul.f32 q4, q1, %q[vcoef_6]  @mul * 1/9.f\n"
            "sub %[dr0], #8                @sub w,8\n"
            "sub %[dr1], #8                @sub w,8\n"
            "subs %[cnt_num], #1           @subs cnt_num,#1\n"
            "vst1.f32 d8, [%[dr_out]]!     @vst1 d0,dr_out\n"
            "vst1.f32 d9, [%[dr_out]]!     @vst1 d0,dr_out\n"
            "bne 1b                        @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_sum = r0[j + w] + r1[j + w];
        tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
        tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
        dout_ch[j + w + 1] = tmp_sum * coef_6;
      }
      // right
      float tmp = r0[win - 2] + r1[win - 2];
      tmp += (r0[win - 1] + r1[win - 1]);
      dout_ch[wout - 1] = tmp * coef_4;
      // r0 = r1;
      // r1 = r0 + w_in;
      // r2 = r1 + w_in;
      dout_ch += wout;
      int h = 0;
      for (; h < hin - 2; h += 1) {
        // deal with left pad
        float maxr0 = r0[0] + r0[1];
        float maxr1 = r1[0] + r1[1];
        float maxr2 = r2[0] + r2[1];
        dout_ch[0] = (maxr0 + maxr1 + maxr2) * coef_6;
#ifdef __aarch64__
        w = 0;
        cnt = 1;
        for (; w < w_unroll_size; w += 4) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
          vsum = vaddq_f32(vsum, vsum_3456);
          vsum = vmulq_f32(vsum, vcoef);
          vst1q_f32(&dout_ch[cnt], vsum);
          cnt += 4;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = r0;
        dr1 = r1;
        dr2 = r2;
        cnt_num = w_unroll_size >> 2;
        if (cnt_num > 0) {
          asm volatile(
              "1:                           @main loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!  @load d0-d5,dr0\n"
              "vld1.f32 {d4-d5}, [%[dr1]]!  @load d4-d7,dr1\n"
              "vld1.f32 {d8-d9}, [%[dr2]]!  @load d4-d7,dr1\n"
              "vld1.f32 {d2}, [%[dr0]]!     @load d0-d5,dr0\n"
              "vld1.f32 {d6}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vld1.f32 {d10}, [%[dr2]]!    @load d4-d7,dr1\n"
              "vadd.f32 q7, q0, q2          @add r0_1234,r1_1234\n"
              "vadd.f32 d16, d2, d6         @add r0_5678,r1_5678\n"
              "vadd.f32 q3, q7, q4          @add r0_1234,r1_1234\n"
              "vadd.f32 d12, d16, d10       @add r0_5678,r1_5678\n"
              //"vmov.f32 s7,s6             @mov s7,s6\n"
              "vext.f32 q0, q3, q6, #1      @vext max_2345\n"
              "vext.f32 q2, q3, q6, #2      @vext max_3456\n"
              "vadd.f32 q1, q3, q0          @add 1234+2345\n"
              "vadd.f32 q1, q1, q2          @add+3456\n"
              "vmul.f32 q4, q1, %q[vcoef]   @mul*1/9.f\n"
              "sub %[dr0], #8               @sub w,8\n"
              "sub %[dr1], #8               @sub w,8\n"
              "sub %[dr2], #8               @sub w,8\n"
              "subs %[cnt_num], #1          @subs cnt_num,#1\n"
              "vst1.f32 d8, [%[dr_out]]!    @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!    @vst1 d0,dr_out\n"
              "bne 1b                       @bne s1_max_loop\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [vcoef] "+w"(vcoef)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8");
        }
#endif
        // remain
        w = w_unroll_size;
        for (int j = 0; j < w_unroll_remain; j++) {
          float tmp_sum = r0[j + w] + r1[j + w];
          tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
          tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
          tmp_sum += (r2[j + w + 1] + r2[j + w + 2]);
          tmp_sum += r2[j + w];
          dout_ch[j + w + 1] = tmp_sum * coef;
        }
        // right
        tmp = r0[win - 2] + r1[win - 2];
        tmp += (r0[win - 1] + r1[win - 1]);
        tmp += (r2[win - 2] + r2[win - 1]);
        dout_ch[wout - 1] = tmp * coef_6;
        r0 = r1;
        r1 = r2;
        r2 = r1 + win;
        dout_ch += wout;
      }
      // last line
      float maxr0 = (r0[0] + r0[1]);
      float maxr1 = (r1[0] + r1[1]);
      dout_ch[0] = (maxr0 + maxr1) * coef_4;
#ifdef __aarch64__
      w = 0;
      cnt = 1;
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
        float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
        float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
        float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
        float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
        vsum = vaddq_f32(vsum, vsum_3456);
        vsum = vmulq_f32(vsum, vcoef_6);
        vst1q_f32(&dout_ch[cnt], vsum);
        cnt += 4;
      }
#else
      dr_out = dout_ch + 1;
      dr0 = r0;
      dr1 = r1;
      cnt_num = w_unroll_size >> 2;
      if (cnt_num > 0) {
        asm volatile(
            "1:                            @main loop\n"
            "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d5,dr0\n"
            "vld1.f32 {d4-d5}, [%[dr1]]!   @load d4-d7,dr1\n"
            "vld1.f32 {d2}, [%[dr0]]!      @load d0-d5,dr0\n"
            "vld1.f32 {d6}, [%[dr1]]!      @load d4-d7,dr1\n"
            "vadd.f32 q5, q0, q2           @add r0_1234,r1_1234\n"
            "vadd.f32 d12, d2, d6          @add r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6              @mov s7,s6\n"
            "vext.f32 q0, q5, q6, #1       @vext max_2345\n"
            "vext.f32 q2, q5, q6, #2       @vext max_3456\n"
            "vadd.f32 q1, q5, q0           @add 1234+2345\n"
            "vadd.f32 q1, q1, q2           @add + 3456\n"
            "vmul.f32 q4, q1, %q[vcoef_6]  @mul * 1/9.f\n"
            "sub %[dr0], #8                @sub w,8\n"
            "sub %[dr1], #8                @sub w,8\n"
            "subs %[cnt_num], #1           @subs cnt_num,#1\n"
            "vst1.f32 d8, [%[dr_out]]!     @vst1 d0,dr_out\n"
            "vst1.f32 d9, [%[dr_out]]!     @vst1 d0,dr_out\n"
            "bne 1b                        @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_sum = r0[j + w] + r1[j + w];
        tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
        tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
        dout_ch[j + w + 1] = tmp_sum * coef_6;
      }
      // right
      tmp = r0[win - 2] + r1[win - 2];
      tmp += (r0[win - 1] + r1[win - 1]);
      dout_ch[wout - 1] = tmp * coef_4;
    }
  }
}
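
// 3x3 max pooling, stride 2, padding 1. Outputs sit two input columns apart,
// so the vector path pairs pmax results from the 1234/2345 and 5678/6789
// windows to emit four outputs per eight input columns.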
void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int kernel = 3;
  int stride = 2;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = (w_limit >> 1) << 1;
  int h_even = (h_limit >> 1) << 1;
  int w_unroll_size = ((w_even - 1) >> 3) << 3;
  int w_unroll_remain = w_even - 1 - w_unroll_size;
  int w_remain = w_needed - w_limit - padding;
  int h_remain = h_needed - h_limit - padding;
  int w_in_2 = win << 1;
  float minval = std::numeric_limits<float>::lowest();
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 3;
      int cnt_num_remain = w_unroll_remain >> 1;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 1;
      int cnt = 1;
      dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#if __aarch64__
      for (; w < w_unroll_size; w += 8) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
        float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
        float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
        float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
        float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
        float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
        float32x2_t vmax_12_34 =
            vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
        float32x2_t vmax_23_45 =
            vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
        float32x2_t vmax_56_78 =
            vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
        float32x2_t vmax_67_89 =
            vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
        float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
        float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
        vst1_f32(&dout_ch[cnt], vmax_123_345);
        vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
        cnt += 4;
      }
      for (; w < w_even - 1; w += 2) {
        float32x4_t vr0 = vld1q_f32(&r0[w]);
        float32x4_t vr1 = vld1q_f32(&r1[w]);
        vr0 = vsetq_lane_f32(minval, vr0, 3);
        vr1 = vsetq_lane_f32(minval, vr1, 3);
        float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
        float32x2_t vmax2 =
            vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
        vmax2 = vpmax_f32(vmax2, vmax2);
        dout_ch[cnt] = vget_lane_f32(vmax2, 0);
        cnt++;
      }
#else
      dr0 = dr0 + 1;
      dr1 = dr1 + 1;
      dr_out = dr_out + 1;
      // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
      // cnt_num_remain;
      if (cnt_num > 0 || cnt_num_remain > 0) {
        asm volatile(
            "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
            "ble 3f                        @ble exit\n"
            "1:                            @main loop\n"
            "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d5,dr0\n"
            "vld1.f32 {d6-d9}, [%[dr1]]!   @load d4-d7,dr1\n"
            "vld1.f32 {d4-d5}, [%[dr0]]!   @load d0-d5,dr0\n"
            "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
            "vmax.f32 q6, q0, q3           @max r0_1234,r1_1234\n"
            "vmax.f32 q7, q1, q4           @max r0_5678,r1_5678\n"
            "vmax.f32 q8, q2, q5           @max r0_9101112,r1_9101112\n"
            //"vmov.f32 s7,s6              @mov s7,s6\n"
            "vext.f32 q0, q6, q7, #1       @vext max_2345\n"
            "vext.f32 q1, q7, q8, #1       @vext max_6789\n"
            "vpmax.f32 d4, d12, d13        @pmax d4,vmax_1234,vmax_1234\n"
            "vpmax.f32 d6, d14, d15        @pmax d6,vmax_5678,vmax_5678\n"
            "vpmax.f32 d5, d0, d1          @pmax d5,vmax_2345,vmax_2345\n"
            "vpmax.f32 d7, d2, d3          @pmax d7,vmax_6789,vmax_6789\n"
            "vmax.f32 d8, d4, d5           @max d2,vmax_12_34,vmax_23_45\n"
            "vmax.f32 d9, d6, d7           @max d2,vmax_56_78,vmax_67_89\n"
            "sub %[dr0], #16               @add w,8\n"
            "sub %[dr1], #16               @add w, 8\n"
            "vst1.f32 d8, [%[dr_out]]!     @vst1 d0,dr_out\n"
            "vst1.f32 d9, [%[dr_out]]!     @vst1 d0,dr_out\n"
            "subs %[cnt_num], #1           @subs cnt_num, #1\n"
            "bne 1b                        @bne s3_max_loop\n"
            "3:                            @loop\n"
            "cmp %[cnt_num_remain], #0     @cmp cnt_num,0\n"
            "ble 4f                        @ble exit\n"
            "2:                            @main loop\n"
            "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
            "vld1.f32 {d2-d3}, [%[dr1]]!   @load d2-d3,dr1\n"
            "vmov.f32 s3,s2                @movs3,s2\n"
            "vmov.f32 s7,s6                @movs7,s6\n"
            "vmax.f32 q0, q0, q1           @max q0,q0,q1\n"
            "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
            "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
            "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
            "sub %[dr0], #8                @add w,6\n"
            "sub %[dr1], #8                @add w,6\n"
            "subs %[cnt_num_remain], #1    @subs cnt_num,#1\n"
            "bne 2b                        @bne s3_max_loop_1\n"
            "4:                            @exit\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
              "r"(cnt_num_remain)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9");
      }
#endif
      // int w = w_even - 1;
      if (w_remain > 0) {
        // deal with right pad
        int wstart = (w_even >> 1) * stride - padding;
        int wend = std::min(std::min(wstart + kernel, win + padding), win);
        float tmp = r0[wstart];  // std::numeric_limits<float>::min();
        for (int i = wstart; i < wend; i++) {  // only run 1 or 2 times
          tmp = std::max(tmp, std::max(r0[i], r1[i]));
        }
        dout_ch[w_even >> 1] = tmp;
        // cnt ++;
      }
      r0 = r1;
      r1 = r0 + win;
      r2 = r1 + win;
      dout_ch += wout;
      int h = 2;
      for (; h < h_even; h += 2) {
        // deal with left pad
        float maxr0 = std::max(r0[0], r0[1]);
        float maxr1 = std::max(r1[0], r1[1]);
        float maxr2 = std::max(r2[0], r2[1]);
        dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#if __aarch64__
        w = 1;
        cnt = 1;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
          float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
          vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_56_78 =
              vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
          float32x2_t vmax_67_89 =
              vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
          vst1_f32(&dout_ch[cnt], vmax_123_345);
          vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
          cnt += 4;
        }
        for (; w < w_even - 1; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(minval, vr0, 3);
          vr1 = vsetq_lane_f32(minval, vr1, 3);
          vr2 = vsetq_lane_f32(minval, vr2, 3);
          float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
          vmax1 = vmaxq_f32(vmax1, vr2);
          float32x2_t vmax2 =
              vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
          float32x2_t vmax = vpmax_f32(vmax2, vmax2);
          dout_ch[cnt] = vget_lane_f32(vmax, 0);
          cnt++;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = (r0 + 1);
        dr1 = (r1 + 1);
        dr2 = (r2 + 1);
        cnt_num = w_unroll_size >> 3;
        cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
              "ble 3f                        @ble exit\n"
              "1:                            @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!   @load d4-d7,dr1\n"
              "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr0]]!   @load d0-d5,dr0\n"
              "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
              "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n"
              "vmax.f32 q9, q0, q3           @max q0,q0,q2\n"
              "vmax.f32 q10, q1, q4          @max q1,q1,q3\n"
              "vmax.f32 q11, q2, q5          @max q1,q1,q3\n"
              "vmax.f32 q0, q9, q6           @max q0,q0,q2 1234\n"
              "vmax.f32 q3, q10, q7          @max q1,q1,q3 5678\n"
              "vmax.f32 q1, q11, q8          @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6              @mov s7, s6\n"
              "vext.f32 q4, q0, q3, #1       @vext 2345\n"
              "vext.f32 q2, q3, q1, #1       @vext 6789\n"
              "vpmax.f32 d10, d0, d1         @pmax d10,vmax_1234,vmax_1234\n"
              "vpmax.f32 d12, d6, d7         @pmax d12,vmax_5678,vmax_5678\n"
              "vpmax.f32 d11, d8, d9         @pmax d11,vmax_2345,vmax_2345\n"
              "vpmax.f32 d13, d4, d5         @pmax d13,vmax_6789,vmax_6789\n"
              "vmax.f32 d0, d10, d11         @pmax d0,vmax_12_34,vmax_23_45\n"
              "vmax.f32 d1, d12, d13         @pmax d1,vmax_56_78,vmax_67_89\n"
              "sub %[dr0], #16               @add w,8\n"
              "sub %[dr1], #16               @add w,8\n"
              "sub %[dr2], #16               @add w,8\n"
              "vst1.f32 d0, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "vst1.f32 d1, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "subs %[cnt_num], #1           @subs cnt_num,#1\n"
              "bne 1b                        @bne s3_max_loop_mid\n"
              "3:                            @loop\n"
              "cmp %[cnt_num_remain], #0     @cmp cnt_num,0\n"
              "ble 4f                        @ble exit1\n"
              "2:                            @mid loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!   @load d2-d3,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr2]]!   @load d2-d3,dr1\n"
              "vmov.f32 s3,s2                @movs3,s2\n"
              "vmov.f32 s7,s6                @movs7,s6\n"
              "vmov.f32 s11,s10              @movs11,s10\n"
              "vmax.f32 q0, q0, q1           @max q0,q0,q1\n"
              "vmax.f32 q0, q0, q2           @max q0,q0,q2\n"
              "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
              "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
              "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
              "sub %[dr0], #8                @add w,6\n"
              "sub %[dr1], #8                @add w,6\n"
              "sub %[dr2], #8                @add w,6\n"
              "subs %[cnt_num_remain], #1    @subs cnt_num,#1\n"
              "bne 2b                        @bne s3_max_loop_mid_1\n"
              "4:                            @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];  // std::numeric_limits<float>::min();
          for (int i = wstart; i < wend; i++) {
            tmp = std::max(tmp, std::max(r0[i], r1[i]));
            tmp = std::max(tmp, r2[i]);
          }
          dout_ch[w_even >> 1] = tmp;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        int hstart = (h >> 1) * stride - padding;
        int hend = std::min(std::min(hstart + kernel, hin + padding), hin);
        if (hstart == hend - 1) {  // only one line
          dout_ch[0] = std::max(r0[0], r0[1]);
#if __aarch64__
          w = 1;
          cnt = 1;
          for (; w < w_unroll_size; w += 8) {
            float32x4_t vmax_1234 = vld1q_f32(&r0[w]);
            float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]);
            float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]);
            float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
            float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
            float32x2_t vmax_12_34 =
                vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
            float32x2_t vmax_23_45 =
                vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
            float32x2_t vmax_56_78 =
                vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
            float32x2_t vmax_67_89 =
                vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
            float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
            float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
            vst1_f32(&dout_ch[cnt], vmax_123_345);
            vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
            cnt += 4;
          }
          for (; w < w_even - 1; w += 2) {
            float32x4_t vr0 = vld1q_f32(&r0[w]);
            vr0 = vsetq_lane_f32(minval, vr0, 3);
            float32x2_t vmax =
                vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0));
            vmax = vpmax_f32(vmax, vmax);
            dout_ch[cnt] = vget_lane_f32(vmax, 0);
            cnt++;
          }
#else
          dr_out = dout_ch + 1;
          dr0 = (r0 + 1);
          cnt_num = w_unroll_size >> 3;
          cnt_num_remain = w_unroll_remain >> 1;
          // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
          // cnt_num_remain;
          if (cnt_num > 0 || cnt_num_remain > 0) {
            asm volatile(
                "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
                "ble 3f                        @ble exit\n"
                "1:                            @main loop\n"
                "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d3,dr0\n"
                "vld1.f32 {d4-d5}, [%[dr0]]!   @load d0-d3,dr0\n"
                "vext.f32 q4, q0, q1, #1       @vmax_2345\n"
                "vext.f32 q5, q1, q2, #1       @vmax_6789\n"
                "vpmax.f32 d12, d0, d1         @vmax_12_34\n"
                "vpmax.f32 d14, d2, d3         @vmax_56_78\n"
                "vpmax.f32 d13, d8, d9         @vmax_23_45\n"
                "vpmax.f32 d15, d10, d11       @vmax_67_89\n"
                "vmax.f32 d0, d12, d13         @12_34,23_45\n"
                "vmax.f32 d1, d14, d15         @56_78,67_89\n"
                "sub %[dr0], #16               @add w,6\n"
                "vst1.f32 d0, [%[dr_out]]!     @vst1 d0,dr_out\n"
                "vst1.f32 d1, [%[dr_out]]!     @vst1 d0,dr_out\n"
                "subs %[cnt_num], #1           @subs cnt_num,#1\n"
                "bne 1b                        @bne s3_max_loop_bot\n"
                "3:                            @loop\n"
                "cmp %[cnt_num_remain], #0     @cmp cnt_num,0\n"
                "ble 4f                        @ble exit\n"
                "2:                            @bot loop\n"
                "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
                "vmov.f32 s3,s2                @movs3, s2\n"
                "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
                "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
                "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
                "sub %[dr0], #8                @add w,2\n"
                "subs %[cnt_num_remain], #1    @subs cnt_num,#1\n"
                "bne 2b                        @bne s3_max_loop_bot_1\n"
                "4:                            @exit\n"
                : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                  [cnt_num] "+r"(cnt_num),
                  [cnt_num_remain] "+r"(cnt_num_remain)
                : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                  "r"(cnt_num_remain)
                : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                  "q7", "q8");
          }
#endif
          if (w_remain > 0) {
            // deal with right pad
            int wstart = (w_even >> 1) * stride - padding;
            int wend =
                std::min(std::min(wstart + kernel, win + padding), win);
            float tmp = r0[wstart];  // std::numeric_limits<float>::min();
            for (int i = wstart; i < wend; i++) {
              tmp = std::max(tmp, r0[i]);
            }
            dout_ch[w_even >> 1] = tmp;
          }
        } else {  // two lines
          dout_ch[0] =
              std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
#ifdef __aarch64__
          w = 1;
          cnt = 1;
          for (; w < w_unroll_size; w += 8) {
            float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
            float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
            float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
            float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
            float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
            float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
            float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
            float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
            float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
            float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
            float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
            float32x2_t vmax_12_34 =
                vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
            float32x2_t vmax_23_45 =
                vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
            float32x2_t vmax_56_78 =
                vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
            float32x2_t vmax_67_89 =
                vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
            float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
            float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
            vst1_f32(&dout_ch[cnt], vmax_123_345);
            vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
            cnt += 4;
          }
          for (; w < w_even - 1; w += 2) {
            float32x4_t vr0 = vld1q_f32(&r0[w]);
            float32x4_t vr1 = vld1q_f32(&r1[w]);
            vr0 = vsetq_lane_f32(minval, vr0, 3);
            vr1 = vsetq_lane_f32(minval, vr1, 3);
            float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
            float32x2_t vmax2 =
                vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
            vmax2 = vpmax_f32(vmax2, vmax2);
            dout_ch[cnt] = vget_lane_f32(vmax2, 0);
            cnt++;
          }
#else
          dr_out = dout_ch + 1;
          dr0 = (r0 + 1);
          dr1 = (r1 + 1);
          cnt_num = w_unroll_size >> 3;
          cnt_num_remain = w_unroll_remain >> 1;
          // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
          // cnt_num_remain;
          if (cnt_num > 0 || cnt_num_remain > 0) {
            asm volatile(
                "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
                "ble 3f                        @ble exit\n"
                "1:                            @main loop\n"
                "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d5,dr0\n"
                "vld1.f32 {d6-d9}, [%[dr1]]!   @load d4-d7,dr1\n"
                "vld1.f32 {d4-d5}, [%[dr0]]!   @load d0-d3,dr0\n"
                "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
                "vmax.f32 q6, q0, q3           @max q0,q0,q2 1234\n"
                "vmax.f32 q7, q1, q4           @max q1,q1,q3 5678\n"
                "vmax.f32 q8, q2, q5           @max q1,q1,q3 9101112\n"
                //"vmov.f32 s7,s6              @mov s7, s6\n"
                "vext.f32 q0, q6, q7, #1       @vext q0,2345\n"
                "vext.f32 q1, q7, q8, #1       @vext q1,6789\n"
                "vpmax.f32 d4, d12, d13        @pmax d4,vmax_1234,vmax_1234\n"
                "vpmax.f32 d6, d14, d15        @pmax d6,vmax_5678,vmax_5678\n"
                "vpmax.f32 d5, d0, d1          @pmax d5,vmax_2345,vmax_2345\n"
                "vpmax.f32 d7, d2, d3          @pmax d7,vmax_6789,vmax_6789\n"
                "vmax.f32 d8, d4, d5           @max d2,vmax_12_34,vmax_23_45\n"
                "vmax.f32 d9, d6, d7           @max d2,vmax_56_78,vmax_67_89\n"
                "sub %[dr0], #16               @add w,8\n"
                "sub %[dr1], #16               @add w,8\n"
                "vst1.f32 d8, [%[dr_out]]!     @vst1 d0,dr_out\n"
                "vst1.f32 d9, [%[dr_out]]!     @vst1 d0,dr_out\n"
                "subs %[cnt_num], #1           @subs cnt_num,#1\n"
                "bne 1b                        @bne s3_max_loop_bot\n"
                "3:                            @loop\n"
                "cmp %[cnt_num_remain], #0     @cmp cnt_num,0\n"
                "ble 4f                        @ble exit\n"
                "2:                            @bot loop\n"
                "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
                "vld1.f32 {d2-d3}, [%[dr1]]!   @load d2-d3,dr1\n"
                "vmov.f32 s3,s2                @movs3, s2\n"
                "vmov.f32 s7,s6                @movs7, s6\n"
                "vmax.f32 q0, q0, q1           @max q0,q0,q1\n"
                "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
                "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
                "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
                "sub %[dr0], #8                @add w,6\n"
                "sub %[dr1], #8                @add w,6\n"
                "subs %[cnt_num_remain], #1    @subs cnt_num,#1\n"
                "bne 2b                        @bne s3_max_loop_bot_1\n"
                "4:                            @exit\n"
                : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                  [cnt_num] "+r"(cnt_num),
                  [cnt_num_remain] "+r"(cnt_num_remain)
                : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                  "r"(cnt_num_remain)
                : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                  "q7", "q8", "q9");
          }
#endif
          if (w_remain > 0) {
            // deal with right pad
            int wstart = (w_even >> 1) * stride - padding;
            int wend =
                std::min(std::min(wstart + kernel, win + padding), win);
            float tmp = r0[wstart];  // std::numeric_limits<float>::min();
            for (int i = wstart; i < wend; i++) {  // only run 1 or 2 times
              tmp = std::max(tmp, std::max(r0[i], r1[i]));
            }
            dout_ch[w_even >> 1] = tmp;
          }
        }
      }
    }
  }
}
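
// 3x3 average pooling, stride 2, padding 1. Same traversal as the max
// variant, accumulating sums instead; `exclusive` selects per-window border
// divisors (coef_1 through coef_6) rather than a fixed 1/9.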
void
pooling3x3s2p1_avg
(
const
float
*
din
,
float
*
dout
,
int
num
,
int
chout
,
int
hout
,
int
wout
,
int
chin
,
int
hin
,
int
win
,
bool
exclusive
)
{
int
kernel
=
3
;
int
stride
=
2
;
int
padding
=
1
;
int
size_channel_out
=
wout
*
hout
;
int
size_channel_in
=
win
*
hin
;
int
w_needed
=
(
wout
<<
1
)
+
1
;
int
h_needed
=
(
hout
<<
1
)
+
1
;
int
w_limit
=
w_needed
>
win
?
win
:
w_needed
;
int
h_limit
=
h_needed
>
hin
?
hin
:
h_needed
;
int
w_even
=
(
w_limit
>>
1
)
<<
1
;
int
h_even
=
(
h_limit
>>
1
)
<<
1
;
int
w_unroll_size
=
((
w_even
-
1
)
>>
3
)
<<
3
;
int
w_unroll_remain
=
w_even
-
1
-
w_unroll_size
;
int
w_remain
=
w_needed
-
w_limit
-
padding
;
int
h_remain
=
h_needed
-
h_limit
-
padding
;
int
w_in_2
=
win
<<
1
;
const
float
coef
=
1.
f
/
9.
f
;
const
float
coef_1
=
exclusive
?
1.
f
:
coef
;
const
float
coef_2
=
exclusive
?
1.
f
/
2.
f
:
coef
;
const
float
coef_3
=
exclusive
?
1.
f
/
3.
f
:
coef
;
const
float
coef_4
=
exclusive
?
1.
f
/
4.
f
:
coef
;
const
float
coef_6
=
exclusive
?
1.
f
/
6.
f
:
coef
;
float32x4_t
vcoef
=
vdupq_n_f32
(
coef
);
float32x4_t
vcoef_1
=
vdupq_n_f32
(
coef_1
);
float32x4_t
vcoef_2
=
vdupq_n_f32
(
coef_2
);
float32x4_t
vcoef_3
=
vdupq_n_f32
(
coef_3
);
float32x4_t
vcoef_4
=
vdupq_n_f32
(
coef_4
);
float32x4_t
vcoef_6
=
vdupq_n_f32
(
coef_6
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
float
*
dout_batch
=
dout
+
n
*
chout
*
size_channel_out
;
const
float
*
din_batch
=
din
+
n
*
chin
*
size_channel_in
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
chout
;
c
++
)
{
float
*
dout_ch
=
dout_batch
+
c
*
size_channel_out
;
const
float
*
din_ch
=
din_batch
+
c
*
size_channel_in
;
const
float
*
r0
=
din_ch
;
const
float
*
r1
=
r0
+
win
;
const
float
*
r2
=
r1
+
win
;
int
cnt_num
=
w_unroll_size
>>
3
;
int
cnt_num_remain
=
w_unroll_remain
>>
1
;
float
*
dr_out
=
dout_ch
;
const
float
*
dr0
=
r0
;
const
float
*
dr1
=
r1
;
const
float
*
dr2
=
r2
;
int
w
=
1
;
int
cnt
=
1
;
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
dout_ch
[
0
]
=
(
r0
[
0
]
+
r0
[
1
]
+
r1
[
0
]
+
r1
[
1
])
*
coef_4
;
// first row with zero pad
#ifdef __aarch64__
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
vr0_1234
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
vr0_5678
=
vld1q_f32
(
&
r0
[
w
+
4
]);
float32x4_t
vr0_9101112
=
vld1q_f32
(
&
r0
[
w
+
8
]);
float32x4_t
vr1_1234
=
vld1q_f32
(
&
r1
[
w
]);
float32x4_t
vr1_5678
=
vld1q_f32
(
&
r1
[
w
+
4
]);
float32x4_t
vr1_9101112
=
vld1q_f32
(
&
r1
[
w
+
8
]);
float32x4_t
vsum_1234
=
vaddq_f32
(
vr0_1234
,
vr1_1234
);
float32x4_t
vsum_5678
=
vaddq_f32
(
vr0_5678
,
vr1_5678
);
float32x4_t
vsum_9101112
=
vaddq_f32
(
vr0_9101112
,
vr1_9101112
);
float32x4_t
vsum_2345
=
vextq_f32
(
vsum_1234
,
vsum_5678
,
1
);
float32x4_t
vsum_3456
=
vextq_f32
(
vsum_1234
,
vsum_5678
,
2
);
float32x4_t
vsum_4567
=
vextq_f32
(
vsum_1234
,
vsum_5678
,
3
);
float32x4_t
vsum_6789
=
vextq_f32
(
vsum_5678
,
vsum_9101112
,
1
);
float32x4_t
vsum_123_345
=
vaddq_f32
(
vsum_1234
,
vsum_2345
);
vsum_123_345
=
vaddq_f32
(
vsum_123_345
,
vsum_3456
);
float32x4_t
vsum_567_789
=
vaddq_f32
(
vsum_4567
,
vsum_5678
);
vsum_567_789
=
vaddq_f32
(
vsum_567_789
,
vsum_6789
);
vsum_123_345
=
vsetq_lane_f32
(
vgetq_lane_f32
(
vsum_123_345
,
2
),
vsum_123_345
,
1
);
vsum_123_345
=
vsetq_lane_f32
(
vgetq_lane_f32
(
vsum_567_789
,
1
),
vsum_123_345
,
2
);
vsum_123_345
=
vsetq_lane_f32
(
vgetq_lane_f32
(
vsum_567_789
,
3
),
vsum_123_345
,
3
);
float32x4_t
vrst
=
vmulq_f32
(
vsum_123_345
,
vcoef_6
);
vst1q_f32
(
&
dout_ch
[
cnt
],
vrst
);
cnt
+=
4
;
}
for
(;
w
<
w_even
-
1
;
w
+=
2
)
{
float32x4_t
vr0
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
vr1
=
vld1q_f32
(
&
r1
[
w
]);
vr0
=
vsetq_lane_f32
(
0.
f
,
vr0
,
3
);
vr1
=
vsetq_lane_f32
(
0.
f
,
vr1
,
3
);
float32x4_t
vsum1
=
vaddq_f32
(
vr0
,
vr1
);
float32x2_t
vsum2
=
vpadd_f32
(
vget_low_f32
(
vsum1
),
vget_high_f32
(
vsum1
));
vsum2
=
vpadd_f32
(
vsum2
,
vsum2
);
float32x2_t
vrst
=
vmul_f32
(
vsum2
,
vget_low_f32
(
vcoef_6
));
dout_ch
[
cnt
]
=
vget_lane_f32
(
vrst
,
0
);
cnt
++
;
}
#else
dr0
=
dr0
+
1
;
dr1
=
dr1
+
1
;
dr_out
=
dr_out
+
1
;
// LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
// cnt_num_remain;
if
(
cnt_num
>
0
||
cnt_num_remain
>
0
)
{
asm
volatile
(
"cmp %[cnt_num], #0 @cmp cnt_num,0
\n
"
"ble 3f @ble exit
\n
"
"1: @main loop
\n
"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0
\n
"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1
\n
"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0
\n
"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1
\n
"
"vadd.f32 q6, q0, q3 @max r0_1234,r1_1234
\n
"
"vadd.f32 q7, q1, q4 @max r0_5678,r1_5678
\n
"
"vadd.f32 q8, q2, q5 @max r0_9101112,r1_9101112
\n
"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345
\n
"
"vext.f32 q1, q6, q7, #3 @vext max_4567
\n
"
"vext.f32 q2, q6, q7, #2 @vext max_3456
\n
"
"vext.f32 q3, q7, q8, #1 @vext max_6789
\n
"
"vadd.f32 q4, q6, q0 @add 1234, 2345
\n
"
"vadd.f32 q5, q7, q1 @add 5678, 4567
\n
"
"vadd.f32 q4, q4, q2 @add 3456, sum1
\n
"
"vadd.f32 q5, q5, q3 @add 6789, sum2
\n
"
"vmov.f32 s17, s18 @mov
\n
"
"vmov.f32 s18, s21 @mov
\n
"
"vmov.f32 s19, s23 @mov
\n
"
"vmul.f32 q4, q4, %q[vcoef_6] @mul
\n
"
"sub %[dr0], #16 @add w,8
\n
"
"sub %[dr1], #16 @add w,8
\n
"
"subs %[cnt_num], #1 @subs cnt_num,#1
\n
"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out
\n
"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out
\n
"
"bne 1b @bne s3_max_loop
\n
"
"3: @loop
\n
"
"cmp %[cnt_num_remain], #0 @cnt_num_remain<=0
\n
"
"ble 4f @ble exit
\n
"
"2: @main loop
\n
"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0
\n
"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1
\n
"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123
\n
"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123
\n
"
"vadd.f32 q0, q0, q1 @add q0,q0,q1
\n
"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1
\n
"
"vpadd.f32 d0, d0, d0 @padd d0, d0,d0
\n
"
"vmul.f32 d0, d0, %e[vcoef_6] @mul
\n
"
"sub %[dr0], #8 @add w,6
\n
"
"sub %[dr1], #8 @add w,6
\n
"
"subs %[cnt_num_remain], #1 @subs cnt_num,#1
\n
"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out
\n
"
"bne 2b @bne s3_max_loop_1
\n
"
"4: @exit
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr1
]
"+r"
(
dr1
),
[
dr_out
]
"+r"
(
dr_out
),
[
cnt_num
]
"+r"
(
cnt_num
),
[
cnt_num_remain
]
"+r"
(
cnt_num_remain
),
[
vcoef_6
]
"+w"
(
vcoef_6
),
[
vzero
]
"+w"
(
vzero
)
:
"r"
(
dr0
),
"r"
(
dr1
),
"r"
(
dr_out
),
"r"
(
cnt_num
),
"r"
(
cnt_num_remain
)
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
);
}
#endif
// int w = w_even - 1;
if
(
w_remain
>
0
)
{
// deal with right pad
int
wstart
=
(
w_even
>>
1
)
*
stride
-
padding
;
int
wend
=
std
::
min
(
std
::
min
(
wstart
+
kernel
,
win
+
padding
),
win
);
float
tmp1
=
0.
f
;
// std::numeric_limits<float>::min();
float
tmp2
=
exclusive
?
1.0
f
/
(
2.
f
*
(
wend
-
wstart
))
:
coef
;
for
(
int
i
=
wstart
;
i
<
wend
;
i
++
)
{
// only run 1 or 2 times
tmp1
+=
(
r0
[
i
]
+
r1
[
i
]);
}
dout_ch
[
w_even
>>
1
]
=
tmp1
*
tmp2
;
// cnt ++;
}
r0
=
r1
;
r1
=
r0
+
win
;
r2
=
r1
+
win
;
dout_ch
+=
wout
;
int
h
=
2
;
for
(;
h
<
h_even
;
h
+=
2
)
{
// deal with left pad
float
sum0
=
r0
[
0
]
+
r0
[
1
];
float
sum1
=
r1
[
0
]
+
r1
[
1
];
float
sum2
=
r2
[
0
]
+
r2
[
1
];
dout_ch
[
0
]
=
(
sum0
+
sum1
+
sum2
)
*
coef_6
;
#ifdef __aarch64__
w
=
1
;
cnt
=
1
;
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
vr0_1234
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
vr0_5678
=
vld1q_f32
(
&
r0
[
w
+
4
]);
float32x4_t
vr0_9101112
=
vld1q_f32
(
&
r0
[
w
+
8
]);
float32x4_t
vr1_1234
=
vld1q_f32
(
&
r1
[
w
]);
float32x4_t
vr1_5678
=
vld1q_f32
(
&
r1
[
w
+
4
]);
float32x4_t
vr1_9101112
=
vld1q_f32
(
&
r1
[
w
+
8
]);
float32x4_t
vr2_1234
=
vld1q_f32
(
&
r2
[
w
]);
float32x4_t
vr2_5678
=
vld1q_f32
(
&
r2
[
w
+
4
]);
float32x4_t
vr2_9101112
=
vld1q_f32
(
&
r2
[
w
+
8
]);
float32x4_t
vsum_1234
=
vaddq_f32
(
vr0_1234
,
vr1_1234
);
float32x4_t
vsum_5678
=
vaddq_f32
(
vr0_5678
,
vr1_5678
);
float32x4_t
vsum_9101112
=
vaddq_f32
(
vr0_9101112
,
vr1_9101112
);
vsum_1234
=
vaddq_f32
(
vsum_1234
,
vr2_1234
);
vsum_5678
=
vaddq_f32
(
vsum_5678
,
vr2_5678
);
vsum_9101112
=
vaddq_f32
(
vsum_9101112
,
vr2_9101112
);
float32x4_t
vsum_2345
=
vextq_f32
(
vsum_1234
,
vsum_5678
,
1
);
float32x4_t
vsum_3456
=
vextq_f32
(
vsum_1234
,
vsum_5678
,
2
);
float32x4_t
vsum_4567
=
vextq_f32
(
vsum_1234
,
vsum_5678
,
3
);
float32x4_t
vsum_6789
=
vextq_f32
(
vsum_5678
,
vsum_9101112
,
1
);
float32x4_t
vsum_123_345
=
vaddq_f32
(
vsum_1234
,
vsum_2345
);
vsum_123_345
=
vaddq_f32
(
vsum_123_345
,
vsum_3456
);
float32x4_t
vsum_567_789
=
vaddq_f32
(
vsum_4567
,
vsum_5678
);
vsum_567_789
=
vaddq_f32
(
vsum_567_789
,
vsum_6789
);
vsum_123_345
=
vsetq_lane_f32
(
vgetq_lane_f32
(
vsum_123_345
,
2
),
vsum_123_345
,
1
);
vsum_123_345
=
vsetq_lane_f32
(
vgetq_lane_f32
(
vsum_567_789
,
1
),
vsum_123_345
,
2
);
vsum_123_345
=
vsetq_lane_f32
(
vgetq_lane_f32
(
vsum_567_789
,
3
),
vsum_123_345
,
3
);
float32x4_t
vrst
=
vmulq_f32
(
vsum_123_345
,
vcoef
);
vst1q_f32
(
&
dout_ch
[
cnt
],
vrst
);
cnt
+=
4
;
}
for
(;
w
<
w_even
-
1
;
w
+=
2
)
{
float32x4_t
vr0
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
vr1
=
vld1q_f32
(
&
r1
[
w
]);
float32x4_t
vr2
=
vld1q_f32
(
&
r2
[
w
]);
vr0
=
vsetq_lane_f32
(
0.
f
,
vr0
,
3
);
vr1
=
vsetq_lane_f32
(
0.
f
,
vr1
,
3
);
vr2
=
vsetq_lane_f32
(
0.
f
,
vr2
,
3
);
float32x4_t
vsum1
=
vaddq_f32
(
vr0
,
vr1
);
vsum1
=
vaddq_f32
(
vsum1
,
vr2
);
float32x2_t
vsum2
=
vpadd_f32
(
vget_low_f32
(
vsum1
),
vget_high_f32
(
vsum1
));
float32x2_t
vsum
=
vpadd_f32
(
vsum2
,
vsum2
);
dout_ch
[
cnt
]
=
vget_lane_f32
(
vsum
,
0
)
*
coef
;
cnt
++
;
}
#else
dr_out
=
dout_ch
+
1
;
dr0
=
(
r0
+
1
);
dr1
=
(
r1
+
1
);
dr2
=
(
r2
+
1
);
cnt_num
=
w_unroll_size
>>
3
;
cnt_num_remain
=
w_unroll_remain
>>
1
;
if
(
cnt_num
>
0
||
cnt_num_remain
>
0
)
{
asm
volatile
(
"cmp %[cnt_num], #0 @cmp cnt_num,0
\n
"
"ble 3f @ble exit
\n
"
"1: @main loop
\n
"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0
\n
"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1
\n
"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1
\n
"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0
\n
"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1
\n
"
"vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1
\n
"
"vadd.f32 q9, q0, q3 @max q0,q0,q2
\n
"
"vadd.f32 q10, q1, q4 @max q1,q1,q3
\n
"
"vadd.f32 q11, q2, q5 @max q1,q1,q3
\n
"
"vadd.f32 q6, q9, q6 @max q0,q0,q2 1234
\n
"
"vadd.f32 q7, q10, q7 @max q1,q1,q3 5678
\n
"
"vadd.f32 q8, q11, q8 @max q1,q1,q3 9101112
\n
"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345
\n
"
"vext.f32 q1, q6, q7, #3 @vext max_4567
\n
"
"vext.f32 q2, q6, q7, #2 @vext max_3456
\n
"
"vext.f32 q3, q7, q8, #1 @vext max_6789
\n
"
"vadd.f32 q4, q6, q0 @add 1234,2345
\n
"
"vadd.f32 q5, q7, q1 @add 5678,4567
\n
"
"vadd.f32 q4, q4, q2 @add 3456,sum1
\n
"
"vadd.f32 q5, q5, q3 @add 6789,sum2
\n
"
"vmov.f32 s17, s18 @mov
\n
"
"vmov.f32 s18, s21 @mov
\n
"
"vmov.f32 s19, s23 @mov
\n
"
"vmul.f32 q4, q4, %q[vcoef] @mul
\n
"
"sub %[dr0], #16 @add w,8
\n
"
"sub %[dr1], #16 @add w,8
\n
"
"sub %[dr2], #16 @add w, 8
\n
"
"subs %[cnt_num], #1 @subs cnt_num,#1
\n
"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out
\n
"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out
\n
"
"bne 1b @bne s3_max_loop_mid
\n
"
"3: @loop
\n
"
"cmp %[cnt_num_remain], #0 @cnt_num_remain<=0
\n
"
"ble 4f @ble exit1
\n
"
"2: @mid loop
\n
"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0
\n
"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1
\n
"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1
\n
"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123
\n
"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123
\n
"
"vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123
\n
"
"vadd.f32 q0, q0, q1 @add q0,q0,q1
\n
"
"vadd.f32 q0, q0, q2 @add q0,q0,q1
\n
"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1
\n
"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0
\n
"
"vmul.f32 d0, d0, %e[vcoef] @mul
\n
"
"sub %[dr0], #8 @add w,6
\n
"
"sub %[dr1], #8 @add w,6
\n
"
"sub %[dr2], #8 @add w,6
\n
"
"subs %[cnt_num_remain], #1 @cnt_num_remain--
\n
"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out
\n
"
"bne 2b @bne s3_max_loop_mid_1
\n
"
"4: @exit
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr1
]
"+r"
(
dr1
),
[
dr2
]
"+r"
(
dr2
),
[
dr_out
]
"+r"
(
dr_out
),
[
cnt_num
]
"+r"
(
cnt_num
),
[
cnt_num_remain
]
"+r"
(
cnt_num_remain
),
[
vcoef
]
"+w"
(
vcoef
),
[
vzero
]
"+w"
(
vzero
)
:
"r"
(
dr0
),
"r"
(
dr1
),
"r"
(
dr2
),
"r"
(
dr_out
),
"r"
(
cnt_num
),
"r"
(
cnt_num_remain
)
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
);
}
#endif
if
(
w_remain
>
0
)
{
// deal with right pad
int
wstart
=
(
w_even
>>
1
)
*
stride
-
padding
;
int
wend
=
std
::
min
(
std
::
min
(
wstart
+
kernel
,
win
+
padding
),
win
);
float
tmp1
=
0.
f
;
float
tmp2
=
exclusive
?
1.0
f
/
(
3.
f
*
(
wend
-
wstart
))
:
coef
;
for
(
int
i
=
wstart
;
i
<
wend
;
i
++
)
{
tmp1
+=
(
r0
[
i
]
+
r1
[
i
]
+
r2
[
i
]);
}
dout_ch
[
w_even
>>
1
]
=
tmp1
*
tmp2
;
// cnt ++;
}
r0
=
r2
;
r1
=
r0
+
win
;
r2
=
r1
+
win
;
dout_ch
+=
wout
;
}
if
(
h_remain
>
0
)
{
// deal with bottom pad
// first row with zero pad
int
hstart
=
(
h
>>
1
)
*
stride
-
padding
;
int
hend
=
std
::
min
(
std
::
min
(
hstart
+
kernel
,
hin
+
padding
),
hin
);
if
(
hstart
==
hend
-
1
)
{
// only one line
dout_ch
[
0
]
=
(
r0
[
0
]
+
r0
[
1
])
*
coef_2
;
#ifdef __aarch64__
w
=
1
;
cnt
=
1
;
for
(;
w
<
w_unroll_size
;
w
+=
8
)
{
float32x4_t
vsum_1234
=
vld1q_f32
(
&
r0
[
w
]);
float32x4_t
vsum_5678
=
vld1q_f32
(
&
r0
[
w
+
4
]);
float32x4_t
vsum_9101112
=
vld1q_f32
(
&
r0
[
w
+
8
]);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_3);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even - 1; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0));
          vsum = vpadd_f32(vsum, vsum);
          dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef_3;
          cnt++;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = (r0 + 1);
        cnt_num = w_unroll_size >> 3;
        cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0              @cmp cnt_num,0\n"
              "ble 3f                          @ble exit\n"
              "1:                              @main loop\n"
              "vld1.f32 {d12-d15}, [%[dr0]]!   @load d0-d3,dr0\n"
              "vld1.f32 {d16-d17}, [%[dr0]]!   @load d0-d3,dr0\n"
              "vext.f32 q0, q6, q7, #1         @vext max_2345\n"
              "vext.f32 q1, q6, q7, #3         @vext max_4567\n"
              "vext.f32 q2, q6, q7, #2         @vext max_3456\n"
              "vext.f32 q3, q7, q8, #1         @vext max_6789\n"
              "vadd.f32 q4, q6, q0             @add 1234,2345\n"
              "vadd.f32 q5, q7, q1             @add 5678,4567\n"
              "vadd.f32 q4, q4, q2             @add 3456,sum1\n"
              "vadd.f32 q5, q5, q3             @add 6789,sum2\n"
              "vmov.f32 s17, s18               @mov\n"
              "vmov.f32 s18, s21               @mov\n"
              "vmov.f32 s19, s23               @mov\n"
              "vmul.f32 q4, q4, %q[vcoef_3]    @mul\n"
              "sub %[dr0], #16                 @add w,6\n"
              "subs %[cnt_num], #1             @subs cnt_num,#1\n"
              "vst1.f32 d8, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "bne 1b                          @bne s3_max_loop_bot\n"
              "3:                              @loop\n"
              "cmp %[cnt_num_remain], #0       @cnt_num_remain<=0\n"
              "ble 4f                          @ble exit\n"
              "2:                              @bot loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!     @load d0-d1,dr0\n"
              "vext.f32 q0, %q[vzero], q0, #3  @ext v0_0123\n"
              "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0            @padd d0,d0,d0\n"
              "vmul.f32 d0, d0, %e[vcoef_3]    @mul\n"
              "sub %[dr0], #8                  @add w,2\n"
              "subs %[cnt_num_remain], #1      @cnt_num_remain--\n"
              "vst1.f32 d0[0], [%[dr_out]]!    @vst d0[0],dr_out\n"
              "bne 2b                          @bne s3_max_loop_bot_1\n"
              "4:                              @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
                [vcoef_3] "+w"(vcoef_3), [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                "q8");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (1.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            tmp1 += r0[i];
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
        }
      } else {
        // two lines
        dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
#ifdef __aarch64__
        w = 1;
        cnt = 1;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even - 1; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          vr1 = vsetq_lane_f32(0.f, vr1, 3);
          float32x4_t vsum1 = vaddq_f32(vr0, vr1);
          float32x2_t vsum2 =
              vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
          vsum2 = vpadd_f32(vsum2, vsum2);
          float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
          dout_ch[cnt] = vget_lane_f32(vrst, 0);
          cnt++;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = (r0 + 1);
        dr1 = (r1 + 1);
        cnt_num = w_unroll_size >> 3;
        cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0              @cmp cnt_num,0\n"
              "ble 3f                          @ble exit\n"
              "1:                              @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!     @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr0]]!     @load d0-d3,dr0\n"
              "vld1.f32 {d10-d11}, [%[dr1]]!   @load d4-d7,dr1\n"
              "vadd.f32 q6, q0, q3             @add q0,q0,q2 1234\n"
              "vadd.f32 q7, q1, q4             @add q1,q1,q3 5678\n"
              "vadd.f32 q8, q2, q5             @add q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6                @mov s7,s6\n"
              "vext.f32 q0, q6, q7, #1         @vext max_2345\n"
              "vext.f32 q1, q6, q7, #3         @vext max_4567\n"
              "vext.f32 q2, q6, q7, #2         @vext max_3456\n"
              "vext.f32 q3, q7, q8, #1         @vext max_6789\n"
              "vadd.f32 q4, q6, q0             @add 1234,2345\n"
              "vadd.f32 q5, q7, q1             @add 5678,4567\n"
              "vadd.f32 q4, q4, q2             @add 3456,sum1\n"
              "vadd.f32 q5, q5, q3             @add 6789,sum2\n"
              "vmov.f32 s17, s18               @mov\n"
              "vmov.f32 s18, s21               @mov\n"
              "vmov.f32 s19, s23               @mov\n"
              "vmul.f32 q4, q4, %q[vcoef_6]    @mul\n"
              "sub %[dr0], #16                 @add w,8\n"
              "sub %[dr1], #16                 @add w,8\n"
              "subs %[cnt_num], #1             @subs cnt_num,#1\n"
              "vst1.f32 d8, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!       @vst1 d0, dr_out\n"
              "bne 1b                          @bne s3_max_loop_bot\n"
              "3:                              @loop\n"
              "cmp %[cnt_num_remain], #0       @cnt_num_remain<=0\n"
              "ble 4f                          @ble exit\n"
              "2:                              @bot loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!     @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!     @load d2-d3,dr1\n"
              "vext.f32 q0, %q[vzero], q0, #3  @ext v0_0123\n"
              "vext.f32 q1, %q[vzero], q1, #3  @ext v1_0123\n"
              "vadd.f32 q0, q0, q1             @add q0,q0,q1\n"
              "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0            @padd d0,d0,d0\n"
              "vmul.f32 d0, d0, %e[vcoef_6]    @mul\n"
              "sub %[dr0], #8                  @add w,6\n"
              "sub %[dr1], #8                  @add w,6\n"
              "subs %[cnt_num_remain], #1      @cnt_num_remain--\n"
              "vst1.f32 d0[0], [%[dr_out]]!    @vst d0[0],dr_out\n"
              "bne 2b                          @bne s3_max_loop_bot_1\n"
              "4:                              @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
                [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                "q8", "q9");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            // only run 1 or 2 times
            tmp1 += (r0[i] + r1[i]);
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
        }
      }
    }
  }
}
}
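
// All of the NEON paths in these pooling kernels lean on one sliding-window
// idea: load overlapping 4-lane vectors, build the shifted windows with
// vextq_f32, and combine them so every output lane carries a 3-wide sum or
// max. A minimal standalone sketch of that trick (stride-1 case; the helper
// name is an illustrative assumption, not something this file defines):
static inline float32x4_t window3_sum_sketch(const float* in) {
  float32x4_t v0123 = vld1q_f32(in);      // in[0..3]
  float32x4_t v4567 = vld1q_f32(in + 4);  // in[4..7]
  float32x4_t v1234 = vextq_f32(v0123, v4567, 1);  // window shifted by one
  float32x4_t v2345 = vextq_f32(v0123, v4567, 2);  // window shifted by two
  // lane i now holds in[i] + in[i + 1] + in[i + 2], for i = 0..3
  return vaddq_f32(vaddq_f32(v0123, v1234), v2345);
}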
void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int kernel = 3;
  int stride = 2;
  int padding = 0;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = ((w_limit - 1) >> 1) << 1;
  int h_even = ((h_limit - 1) >> 1) << 1;
  int w_unroll_size = (w_even >> 3) << 3;
  int w_unroll_remain = w_even - w_unroll_size;
  int w_remain = w_needed - w_limit;
  int h_remain = h_needed - h_limit;
  int w_in_2 = win << 1;
  float minval = std::numeric_limits<float>::lowest();
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      // w = w_in - 8;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 0;
      int cnt = 0;
      // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0],
      // r1[1]));
      // first row with zero pad
      // r0 = r1;
      // r1 = r0 + w_in;
      // r2 = r1 + w_in;
      // dout_channel += w_out;
      int h = 0;
      for (; h < h_even; h += 2) {
        // deal with left pad
        float maxr0 = std::max(r0[0], r0[1]);
        float maxr1 = std::max(r1[0], r1[1]);
        float maxr2 = std::max(r2[0], r2[1]);
        // dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
        w = 0;
        cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
          float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
          vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_56_78 =
              vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
          float32x2_t vmax_67_89 =
              vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
          vst1_f32(&dout_ch[cnt], vmax_123_345);
          vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(minval, vr0, 3);
          vr1 = vsetq_lane_f32(minval, vr1, 3);
          vr2 = vsetq_lane_f32(minval, vr2, 3);
          float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
          vmax1 = vmaxq_f32(vmax1, vr2);
          float32x2_t vmax2 =
              vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
          float32x2_t vmax = vpmax_f32(vmax2, vmax2);
          dout_ch[cnt] = vget_lane_f32(vmax, 0);
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        dr2 = r2;          // (r2 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
              "ble 3f                        @ble exit\n"
              "1:                            @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!   @load d4-d7,dr1\n"
              "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n"
              "vld1.f32 {d4}, [%[dr0]]!      @load d0-d5,dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vld1.f32 {d16}, [%[dr2]]!     @load d4-d7,dr1\n"
              "vmax.f32 q9, q0, q3           @max q0,q0,q2\n"
              "vmax.f32 q10, q1, q4          @max q1,q1,q3\n"
              "vmax.f32 d22, d4, d10         @max q1,q1,q3\n"
              "vmax.f32 q0, q9, q6           @max q0,q0,q2 1234\n"
              "vmax.f32 q3, q10, q7          @max q1,q1,q3 5678\n"
              "vmax.f32 d2, d22, d16         @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6              @mov s7, s6\n"
              "vext.f32 q4, q0, q3, #1       @vext 2345\n"
              "vext.f32 q2, q3, q1, #1       @vext 6789\n"
              "vpmax.f32 d10, d0, d1         @pmax d10,vmax_1234,vmax_1234\n"
              "vpmax.f32 d12, d6, d7         @pmax d12,vmax_5678,vmax_5678\n"
              "vpmax.f32 d11, d8, d9         @pmax d11,vmax_2345,vmax_2345\n"
              "vpmax.f32 d13, d4, d5         @pmax d13,vmax_6789,vmax_6789\n"
              "vmax.f32 d0, d10, d11         @pmax d0,vmax_12_34,vmax_23_45\n"
              "vmax.f32 d1, d12, d13         @pmax d1,vmax_56_78,vmax_67_89\n"
              "sub %[dr0], #8                @add w,8\n"
              "sub %[dr1], #8                @add w,8\n"
              "sub %[dr2], #8                @add w,8\n"
              "vst1.f32 d0, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "vst1.f32 d1, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "subs %[cnt_num], #1           @cnt_num--\n"
              "bne 1b                        @bne s3_max_loop_mid\n"
              "3:                            @loop\n"
              "cmp %[cnt_num_remain], #0     @cmp cnt_num_remain,0\n"
              "ble 4f                        @ble exit1\n"
              "2:                            @mid loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!   @load d2-d3,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr2]]!   @load d2-d3,dr1\n"
              "vmov.f32 s3,s2                @movs3,s2\n"
              "vmov.f32 s7,s6                @movs7,s6\n"
              "vmov.f32 s11,s10              @movs11,s10\n"
              "vmax.f32 q0, q0, q1           @max q0,q0,q1\n"
              "vmax.f32 q0, q0, q2           @max q0,q0,q2\n"
              "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
              "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
              "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
              "sub %[dr0], #8                @add w,6\n"
              "sub %[dr1], #8                @add w,6\n"
              "sub %[dr2], #8                @add w,6\n"
              "subs %[cnt_num_remain], #1    @cnt_num_remain--\n"
              "bne 2b                        @bne s3_max_loop_mid_1\n"
              "4:                            @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                "q8", "q9", "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];  // std::numeric_limits<float>::min();
          for (int i = wstart; i < wend; i++) {
            tmp = std::max(tmp, std::max(r0[i], r1[i]));
            tmp = std::max(tmp, r2[i]);
          }
          dout_ch[w_even >> 1] = tmp;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        // int hstart = (h >> 1) * stride_h - pad_h;
        // int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin);
        // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0],
        // r1[1]));
#ifdef __aarch64__
        w = 0;
        cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_56_78 =
              vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
          float32x2_t vmax_67_89 =
              vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
          vst1_f32(&dout_ch[cnt], vmax_123_345);
          vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          vr0 = vsetq_lane_f32(minval, vr0, 3);
          vr1 = vsetq_lane_f32(minval, vr1, 3);
          float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
          float32x2_t vmax2 =
              vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
          vmax2 = vpmax_f32(vmax2, vmax2);
          dout_ch[cnt] = vget_lane_f32(vmax2, 0);
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
              "ble 3f                        @ble exit\n"
              "1:                            @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!   @load d4-d7,dr1\n"
              "vld1.f32 {d4}, [%[dr0]]!      @load d0-d3,dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vmax.f32 q6, q0, q3           @max q0,q0,q2 1234\n"
              "vmax.f32 q7, q1, q4           @max q1,q1,q3 5678\n"
              "vmax.f32 d16, d4, d10         @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6              @mov s7,s6\n"
              "vext.f32 q0, q6, q7, #1       @vext q0,2345\n"
              "vext.f32 q1, q7, q8, #1       @vext q1,6789\n"
              "vpmax.f32 d4, d12, d13        @pmax d4,vmax_1234,vmax_1234\n"
              "vpmax.f32 d6, d14, d15        @pmax d6,vmax_5678,vmax_5678\n"
              "vpmax.f32 d5, d0, d1          @pmax d5,vmax_2345,vmax_2345\n"
              "vpmax.f32 d7, d2, d3          @pmax d7,vmax_6789,vmax_6789\n"
              "vmax.f32 d8, d4, d5           @max d2,vmax_12_34,vmax_23_45\n"
              "vmax.f32 d9, d6, d7           @max d2,vmax_56_78,vmax_67_89\n"
              "sub %[dr0], #8                @add w,8\n"
              "sub %[dr1], #8                @add w,8\n"
              "vst1.f32 d8, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "subs %[cnt_num], #1           @subs cnt_num,#1\n"
              "bne 1b                        @bne s3_max_loop_bot\n"
              "3:                            @loop\n"
              "cmp %[cnt_num_remain], #0     @cmp cnt_num_remain,0\n"
              "ble 4f                        @ble exit\n"
              "2:                            @bot loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!   @load d2-d3,dr1\n"
              "vmov.f32 s3,s2                @movs3,s2\n"
              "vmov.f32 s7,s6                @movs7,s6\n"
              "vmax.f32 q0, q0, q1           @max q0,q0,q1\n"
              "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
              "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
              "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
              "sub %[dr0], #8                @add w,6\n"
              "sub %[dr1], #8                @add w,6\n"
              "subs %[cnt_num_remain], #1    @cnt_num_remain--\n"
              "bne 2b                        @bne s3_max_loop_bot_1\n"
              "4:                            @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                "q8", "q9");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];  // std::numeric_limits<float>::min();
          for (int i = wstart; i < wend; i++) {
            // only run 1 or 2 times
            tmp = std::max(tmp, std::max(r0[i], r1[i]));
          }
          dout_ch[w_even >> 1] = tmp;
        }
      }
    }
  }
}
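
// A scalar reference of what pooling3x3s2p0_max computes per channel is
// useful when checking the intrinsics and assembly paths above. This is an
// editorial sketch (the helper name is assumed, it is not part of the
// kernel): same NCHW layout per channel, 3x3 window, stride 2, no padding.
static void pooling3x3s2p0_max_ref_sketch(const float* din, float* dout,
                                          int hin, int win, int hout,
                                          int wout) {
  for (int oh = 0; oh < hout; ++oh) {
    for (int ow = 0; ow < wout; ++ow) {
      int hstart = oh * 2;                   // stride 2, pad 0
      int wstart = ow * 2;
      int hend = std::min(hstart + 3, hin);  // 3x3 window clipped at edges
      int wend = std::min(wstart + 3, win);
      float m = din[hstart * win + wstart];
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          m = std::max(m, din[h * win + w]);
        }
      }
      dout[oh * wout + ow] = m;
    }
  }
}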
void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive) {
  int kernel = 3;
  int stride = 2;
  int padding = 0;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = ((w_limit - 1) >> 1) << 1;
  int h_even = ((h_limit - 1) >> 1) << 1;
  int w_unroll_size = (w_even >> 3) << 3;
  int w_unroll_remain = w_even - w_unroll_size;
  int w_remain = w_needed - w_limit;
  int h_remain = h_needed - h_limit;
  int w_in_2 = win << 1;
  const float coef = 1.f / 9.f;
  const float coef_6 = exclusive ? 1.f / 6.f : coef;
  float32x4_t vcoef = vdupq_n_f32(coef);
  float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      // w = w_in - 8;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      float32x4_t vzero = vdupq_n_f32(0.f);
      int h = 0;
      for (; h < h_even; h += 2) {
        // LOG(INFO) << "h: " << h <<", dr0:" << r0 << ", dr1: " << r1 <<
        // ",dr2: " <<r2; deal with left pad float sum0 = r0[0] + r0[1]; float
        // sum1 = r1[0] + r1[1]; float sum2 = r2[0] + r2[1]; dout_channel[0] =
        // (sum0 + sum1 + sum2) / 9.f;
#ifdef __aarch64__
        int w = 0;
        int cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
          vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
          vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
          vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          vr1 = vsetq_lane_f32(0.f, vr1, 3);
          vr2 = vsetq_lane_f32(0.f, vr2, 3);
          float32x4_t vsum1 = vaddq_f32(vr0, vr1);
          vsum1 = vaddq_f32(vsum1, vr2);
          float32x2_t vsum2 =
              vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
          float32x2_t vsum = vpadd_f32(vsum2, vsum2);
          dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef;
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        dr2 = r2;          // (r2 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
        // cnt_num_remain;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0              @cmp cnt_num, 0\n"
              "ble loop3_ave_p0                @ble exit\n"
              "s3_ave_loop_mid_p0:             @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!     @load d0-d5, dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!     @load d4-d7, dr1\n"
              "vld1.f32 {d12-d15}, [%[dr2]]!   @load d4-d7, dr2\n"
              "vld1.f32 {d4}, [%[dr0]]!        @load d0-d5, dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!       @load d4-d7, dr1\n"
              "vld1.f32 {d16}, [%[dr2]]!       @load d4-d7, dr2\n"
              "vadd.f32 q9, q0, q3             @max q0,q0,q2\n"
              "vadd.f32 q10, q1, q4            @max q1,q1,q3\n"
              "vadd.f32 d22, d4, d10           @max q1,q1,q3\n"
              "vadd.f32 q6, q9, q6             @max q0,q0,q2 1234\n"
              "vadd.f32 q7, q10, q7            @max q1,q1,q3 5678\n"
              "vadd.f32 d16, d22, d16          @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6                @mov s7, s6\n"
              "vext.f32 q0, q6, q7, #1         @vext max_2345\n"
              "vext.f32 q1, q6, q7, #3         @vext max_4567\n"
              "vext.f32 q2, q6, q7, #2         @vext max_3456\n"
              "vext.f32 q3, q7, q8, #1         @vext max_6789\n"
              "vadd.f32 q4, q6, q0             @add 1234, 2345\n"
              "vadd.f32 q5, q7, q1             @add 5678, 4567\n"
              "vadd.f32 q4, q4, q2             @add 3456, sum1\n"
              "vadd.f32 q5, q5, q3             @add 6789, sum2\n"
              "vmov.f32 s17, s18               @mov\n"
              "vmov.f32 s18, s21               @mov\n"
              "vmov.f32 s19, s23               @mov\n"
              "vmul.f32 q4, q4, %q[vcoef]      @mul\n"
              "sub %[dr0], #8                  @add w,8\n"
              "sub %[dr1], #8                  @add w,8\n"
              "sub %[dr2], #8                  @add w,8\n"
              "subs %[cnt_num], #1             @cnt_num--\n"
              "vst1.f32 d8, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "bne s3_ave_loop_mid_p0          @bne s3_max_loop_mid\n"
              "loop3_ave_p0:                   @loop\n"
              "cmp %[cnt_num_remain], #0       @cmp cnt_num_remain,0\n"
              "ble exit1_ave_p0                @ble exit1\n"
              "s3_ave_loop_mid_1_p0:           @mid loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!     @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!     @load d2-d3,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr2]]!     @load d2-d3,dr1\n"
              "vext.f32 q0, %q[vzero], q0, #3  @ext v0_0123\n"
              "vext.f32 q1, %q[vzero], q1, #3  @ext v1_0123\n"
              "vext.f32 q2, %q[vzero], q2, #3  @ext v1_0123\n"
              "vadd.f32 q0, q0, q1             @add q0,q0,q1\n"
              "vadd.f32 q0, q0, q2             @add q0,q0,q1\n"
              "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0            @padd d0,d0,d0\n"
              "vmul.f32 d0, d0, %e[vcoef]      @mul\n"
              "sub %[dr0], #8                  @add w,6\n"
              "sub %[dr1], #8                  @add w,6\n"
              "sub %[dr2], #8                  @add w,6\n"
              "subs %[cnt_num_remain], #1      @cnt_num_remain--\n"
              "vst1.f32 d0[0], [%[dr_out]]!    @vst d0[0],dr_out\n"
              "bne s3_ave_loop_mid_1_p0        @bne s3_max_loop_mid_1\n"
              "exit1_ave_p0:                   @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef),
                [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
                "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            tmp1 += (r0[i] + r1[i] + r2[i]);
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        // int hstart = (h >> 1) * stride_h - pad_h;
        // int hend = std::min(std::min(hstart + kernel_h, hin + padding_h),
        // hin); data_out_channel[0] =(r0[0] + r0[1] + r0[2] + r1[0] + r1[1] +
        // r1[2]) / 9.f;
#ifdef __aarch64__
        int w = 0;
        int cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          vr1 = vsetq_lane_f32(0.f, vr1, 3);
          float32x4_t vsum1 = vaddq_f32(vr0, vr1);
          float32x2_t vsum2 =
              vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
          vsum2 = vpadd_f32(vsum2, vsum2);
          float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
          dout_ch[cnt] = vget_lane_f32(vrst, 0);
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
        // cnt_num_remain;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0              @cmp cnt_num,0\n"
              "ble 2f                          @ble exit\n"
              "1:                              @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!     @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vld1.f32 {d4}, [%[dr0]]!        @load d0-d3,dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!       @load d4-d7,dr1\n"
              "vadd.f32 q6, q0, q3             @max q0,q0,q2 1234\n"
              "vadd.f32 q7, q1, q4             @max q1,q1,q3 5678\n"
              "vadd.f32 d16, d4, d10           @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6                @mov s7, s6\n"
              "vext.f32 q0, q6, q7, #1         @vext max_2345\n"
              "vext.f32 q1, q6, q7, #3         @vext max_4567\n"
              "vext.f32 q2, q6, q7, #2         @vext max_3456\n"
              "vext.f32 q3, q7, q8, #1         @vext max_6789\n"
              "vadd.f32 q4, q6, q0             @add 1234,2345\n"
              "vadd.f32 q5, q7, q1             @add 5678,4567\n"
              "vadd.f32 q4, q4, q2             @add 3456,sum1\n"
              "vadd.f32 q5, q5, q3             @add 6789,sum2\n"
              "vmov.f32 s17, s18               @mov\n"
              "vmov.f32 s18, s21               @mov\n"
              "vmov.f32 s19, s23               @mov\n"
              "vmul.f32 q4, q4, %q[vcoef_6]    @mul\n"
              "sub %[dr0], #8                  @add w,8\n"
              "sub %[dr1], #8                  @add w,8\n"
              "subs %[cnt_num], #1             @cnt_num--\n"
              "vst1.f32 d8, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!       @vst1 d0,dr_out\n"
              "bne 1b                          @bne s3_max_loop_bot\n"
              "2:                              @loop\n"
              "cmp %[cnt_num_remain], #0       @cmp cnt_num_remain, 0\n"
              "ble 3f                          @ble exit\n"
              "4:                              @bot loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!     @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!     @load d2-d3,dr1\n"
              "vext.f32 q0, %q[vzero], q0, #3  @ext v0_0123\n"
              "vext.f32 q1, %q[vzero], q1, #3  @ext v1_0123\n"
              "vadd.f32 q0, q0, q1             @add q0,q0,q1\n"
              "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0            @padd d0,d0,d0\n"
              "vmul.f32 d0, d0, %e[vcoef_6]    @mul\n"
              "sub %[dr0], #8                  @add w,6\n"
              "sub %[dr1], #8                  @add w,6\n"
              "subs %[cnt_num_remain], #1      @cnt_num_remain--\n"
              "vst1.f32 d0[0], [%[dr_out]]!    @vst d0[0],dr_out\n"
              "bne 4b                          @bne s3_max_loop_bot_1\n"
              "3:                              @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
                [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            // only run 1 or 2 times
            tmp1 += (r0[i] + r1[i]);
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
        }
      }
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
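
One detail of the average kernels above is worth isolating: `exclusive` decides whether a window clipped at the border divides by the number of elements actually inside the image (the `wend - wstart` terms) or always by the full 3 x 3 = 9 (the `coef` constant). A minimal scalar sketch of just that coefficient choice, with a hypothetical helper name:

// Divisor for a pooling window clipped to [hstart, hend) x [wstart, wend).
float avg_coef_sketch(bool exclusive, int hstart, int hend, int wstart,
                      int wend) {
  // exclusive: average over the valid elements only; otherwise over 3 * 3.
  return exclusive ? 1.f / ((hend - hstart) * (wend - wstart)) : 1.f / 9.f;
}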
paddle/fluid/lite/arm/math/pooling.h
deleted 100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

// !pooling fp32 Op
void pooling_basic(const float* din, float* dout, int num, int chout,
                   int hout, int wout, int chin, int hin, int win,
                   const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, bool global_pooling,
                   bool exclusive, bool adaptive, bool ceil_mode,
                   bool use_quantizer, const std::string& pooling_type);

void pooling_global_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling_global_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win);

void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win,
                      bool exclusive);

void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive);

void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive);

void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
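
The header only declares the kernels; the caller-side selection is not part of this diff. A hypothetical dispatcher along the lines the names suggest (the routing logic below is an assumption for illustration, not code from this commit):

// Route a 3x3 pooling to a specialized kernel when one exists; a real
// caller would fall back to pooling_basic for the remaining cases.
void pooling3x3_dispatch_sketch(const float* din, float* dout, int num,
                                int chout, int hout, int wout, int chin,
                                int hin, int win, int stride, int padding,
                                bool is_max, bool exclusive) {
  using namespace paddle::lite::arm::math;
  if (stride == 2 && padding == 0) {
    if (is_max) {
      pooling3x3s2p0_max(din, dout, num, chout, hout, wout, chin, hin, win);
    } else {
      pooling3x3s2p0_avg(din, dout, num, chout, hout, wout, chin, hin, win,
                         exclusive);
    }
  } else if (stride == 2 && padding == 1) {
    if (is_max) {
      pooling3x3s2p1_max(din, dout, num, chout, hout, wout, chin, hin, win);
    } else {
      pooling3x3s2p1_avg(din, dout, num, chout, hout, wout, chin, hin, win,
                         exclusive);
    }
  }
}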
paddle/fluid/lite/arm/math/scale.cc
deleted 100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/scale.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void scale<float>(const float* din, float* dout, int num, float scale,
                  float bias) {
  int cnt = num >> 4;
  int remain = num % 16;
  float32x4_t vscale = vdupq_n_f32(scale);
  float32x4_t vbias = vdupq_n_f32(bias);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
    float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
    float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
    float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
    vst1q_f32(dout_ptr, vsum1);
    vst1q_f32(dout_ptr + 4, vsum2);
    vst1q_f32(dout_ptr + 8, vsum3);
    vst1q_f32(dout_ptr + 12, vsum4);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr * scale + bias;
      dout_ptr++;
      din_ptr++;
    }
  }
}

template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
                  int inner_dim, const float* scale_data,
                  const float* bias_data) {
  int cnt = inner_dim >> 4;
  int remain = inner_dim % 16;
  int size = inner_dim * scale_dim;
  for (int n = 0; n < outer_dim; n++) {
    const float* din_ptr_n = din + n * size;
    float* dout_ptr_n = dout + n * size;
#pragma omp parallel for
    for (int i = 0; i < scale_dim; i++) {
      const float* din_ptr = din_ptr_n + i * inner_dim;
      float* dout_ptr = dout_ptr_n + i * inner_dim;
      float scale = scale_data[i];
      float32x4_t vscale = vdupq_n_f32(scale);
      float bias = bias_data[i];
      float32x4_t vbias = vdupq_n_f32(bias);
      for (int j = 0; j < cnt; j++) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        float32x4_t din2 = vld1q_f32(din_ptr + 8);
        float32x4_t din3 = vld1q_f32(din_ptr + 12);
        float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
        float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
        float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
        float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
        din_ptr += 16;
        vst1q_f32(dout_ptr, vsum1);
        vst1q_f32(dout_ptr + 4, vsum2);
        vst1q_f32(dout_ptr + 8, vsum3);
        vst1q_f32(dout_ptr + 12, vsum4);
        dout_ptr += 16;
      }
      for (int j = 0; j < remain; j++) {
        *dout_ptr = *din_ptr * scale + bias;
        dout_ptr++;
        din_ptr++;
      }
    }
  }
}

template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
                  const float* scale_data, const float* bias_data) {
  int cnt = scale_dim >> 4;
  int remain = scale_dim % 16;
  for (int n = 0; n < outer_dim; n++) {
    const float* din_ptr_n = din + n * scale_dim;
    float* dout_ptr_n = dout + n * scale_dim;
#pragma omp parallel for
    for (int i = 0; i < cnt; i++) {
      int idx = i << 4;
      const float* din_ptr = din_ptr_n + idx;
      const float* scale_ptr = scale_data + idx;
      const float* bias_ptr = bias_data + idx;
      float* dout_ptr = dout_ptr_n + idx;
      float32x4_t din0 = vld1q_f32(din_ptr);
      float32x4_t vscale0 = vld1q_f32(scale_ptr);
      float32x4_t vbias0 = vld1q_f32(bias_ptr);
      float32x4_t din1 = vld1q_f32(din_ptr + 4);
      float32x4_t vscale1 = vld1q_f32(scale_ptr + 4);
      float32x4_t vbias1 = vld1q_f32(bias_ptr + 4);
      float32x4_t din2 = vld1q_f32(din_ptr + 8);
      float32x4_t vscale2 = vld1q_f32(scale_ptr + 8);
      float32x4_t vbias2 = vld1q_f32(bias_ptr + 8);
      float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0);
      float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1);
      float32x4_t din3 = vld1q_f32(din_ptr + 12);
      float32x4_t vscale3 = vld1q_f32(scale_ptr + 12);
      float32x4_t vbias3 = vld1q_f32(bias_ptr + 12);
      vst1q_f32(dout_ptr, vsum1);
      vst1q_f32(dout_ptr + 4, vsum2);
      float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2);
      float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3);
      vst1q_f32(dout_ptr + 8, vsum3);
      vst1q_f32(dout_ptr + 12, vsum4);
    }
    int idx = cnt << 4;
    const float* din_ptr = din_ptr_n + idx;
    float* dout_ptr = dout_ptr_n + idx;
    const float* scale_ptr = scale_data + idx;
    const float* bias_ptr = bias_data + idx;
    for (int j = 0; j < remain; j++) {
      *dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr);
      dout_ptr++;
      din_ptr++;
      scale_ptr++;
      bias_ptr++;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
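
A small usage sketch of the flat overload above (toy values; assumes scale.h is included so the template declaration is visible, and the function name below is hypothetical):

#include <vector>

void scale_usage_sketch() {
  std::vector<float> din(100, 2.f);
  std::vector<float> dout(100);
  // dout[i] = din[i] * 0.5f + 1.f for every element
  paddle::lite::arm::math::scale<float>(
      din.data(), dout.data(), static_cast<int>(din.size()), 0.5f, 1.f);
}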
paddle/fluid/lite/arm/math/scale.h
deleted 100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);

template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim,
           const float* scale_data, const float* bias_data);

template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim,
           const float* scale_data, const float* bias_data);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/softmax.cc
deleted 100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/softmax.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace
paddle
{
namespace
lite
{
namespace
arm
{
namespace
math
{
template
<
>
void
softmax_basic
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
compute_size
;
++
i
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner8_axis4
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
3
;
int
remain
=
compute_size
%
8
;
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
8
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get max axis_size == 4
const
float
*
din_ptr
=
din
+
real_index
;
const
float
*
din_ptr1
=
din_ptr
+
inner_num
;
const
float
*
din_ptr2
=
din_ptr1
+
inner_num
;
const
float
*
din_ptr3
=
din_ptr2
+
inner_num
;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr1
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr2
);
float32x4_t
vdata3
=
vld1q_f32
(
din_ptr3
);
float32x4_t
vdata01
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
vdata11
=
vld1q_f32
(
din_ptr1
+
4
);
float32x4_t
vdata21
=
vld1q_f32
(
din_ptr2
+
4
);
float32x4_t
vdata31
=
vld1q_f32
(
din_ptr3
+
4
);
float
*
dout_ptr0
=
dout
+
real_index
;
float
*
dout_ptr1
=
dout_ptr0
+
inner_num
;
float32x4_t
vmax1
=
vmaxq_f32
(
vdata0
,
vdata1
);
float32x4_t
vmax2
=
vmaxq_f32
(
vdata2
,
vdata3
);
float32x4_t
vmax11
=
vmaxq_f32
(
vdata01
,
vdata11
);
float32x4_t
vmax21
=
vmaxq_f32
(
vdata21
,
vdata31
);
float
*
dout_ptr2
=
dout_ptr1
+
inner_num
;
float
*
dout_ptr3
=
dout_ptr2
+
inner_num
;
float32x4_t
vmax
=
vmaxq_f32
(
vmax1
,
vmax2
);
float32x4_t
vmax_1
=
vmaxq_f32
(
vmax11
,
vmax21
);
// sub, exp and sum
float32x4_t
vsum0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
float32x4_t
vsum1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax
));
float32x4_t
vsum3
=
exp_ps
(
vsubq_f32
(
vdata3
,
vmax
));
float32x4_t
vsum01
=
exp_ps
(
vsubq_f32
(
vdata01
,
vmax_1
));
float32x4_t
vsum11
=
exp_ps
(
vsubq_f32
(
vdata11
,
vmax_1
));
float32x4_t
vsum21
=
exp_ps
(
vsubq_f32
(
vdata21
,
vmax_1
));
float32x4_t
vsum31
=
exp_ps
(
vsubq_f32
(
vdata31
,
vmax_1
));
float32x4_t
vsum_1
=
vaddq_f32
(
vsum0
,
vsum1
);
float32x4_t
vsum_2
=
vaddq_f32
(
vsum2
,
vsum3
);
float32x4_t
vsum_11
=
vaddq_f32
(
vsum01
,
vsum11
);
float32x4_t
vsum_21
=
vaddq_f32
(
vsum21
,
vsum31
);
float32x4_t
vsum
=
vaddq_f32
(
vsum_1
,
vsum_2
);
float32x4_t
vsum111
=
vaddq_f32
(
vsum_11
,
vsum_21
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
float32x4_t
vinf1
=
div_ps
(
vone
,
vsum111
);
vsum0
=
vmulq_f32
(
vsum0
,
vinf
);
vsum1
=
vmulq_f32
(
vsum1
,
vinf
);
vsum2
=
vmulq_f32
(
vsum2
,
vinf
);
vsum3
=
vmulq_f32
(
vsum3
,
vinf
);
vsum01
=
vmulq_f32
(
vsum01
,
vinf1
);
vsum11
=
vmulq_f32
(
vsum11
,
vinf1
);
vsum21
=
vmulq_f32
(
vsum21
,
vinf1
);
vsum31
=
vmulq_f32
(
vsum31
,
vinf1
);
vst1q_f32
(
dout_ptr0
,
vsum0
);
vst1q_f32
(
dout_ptr1
,
vsum1
);
vst1q_f32
(
dout_ptr2
,
vsum2
);
vst1q_f32
(
dout_ptr3
,
vsum3
);
vst1q_f32
(
dout_ptr0
+
4
,
vsum01
);
vst1q_f32
(
dout_ptr1
+
4
,
vsum11
);
vst1q_f32
(
dout_ptr2
+
4
,
vsum21
);
vst1q_f32
(
dout_ptr3
+
4
,
vsum31
);
}
int
i
=
cmp_cnt
*
8
;
if
(
remain
>
4
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get max axis_size == 4
const
float
*
din_ptr
=
din
+
real_index
;
const
float
*
din_ptr1
=
din_ptr
+
inner_num
;
const
float
*
din_ptr2
=
din_ptr1
+
inner_num
;
const
float
*
din_ptr3
=
din_ptr2
+
inner_num
;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr1
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr2
);
float32x4_t
vdata3
=
vld1q_f32
(
din_ptr3
);
float
*
dout_ptr0
=
dout
+
real_index
;
float
*
dout_ptr1
=
dout_ptr0
+
inner_num
;
float32x4_t
vmax1
=
vmaxq_f32
(
vdata0
,
vdata1
);
float32x4_t
vmax2
=
vmaxq_f32
(
vdata2
,
vdata3
);
float
*
dout_ptr2
=
dout_ptr1
+
inner_num
;
float
*
dout_ptr3
=
dout_ptr2
+
inner_num
;
float32x4_t
vmax
=
vmaxq_f32
(
vmax1
,
vmax2
);
// sub, exp and sum
float32x4_t
vsum0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
float32x4_t
vsum1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax
));
float32x4_t
vsum3
=
exp_ps
(
vsubq_f32
(
vdata3
,
vmax
));
float32x4_t
vsum_1
=
vaddq_f32
(
vsum0
,
vsum1
);
float32x4_t
vsum_2
=
vaddq_f32
(
vsum2
,
vsum3
);
float32x4_t
vsum
=
vaddq_f32
(
vsum_1
,
vsum_2
);
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
vsum0
=
vmulq_f32
(
vsum0
,
vinf
);
vsum1
=
vmulq_f32
(
vsum1
,
vinf
);
vsum2
=
vmulq_f32
(
vsum2
,
vinf
);
vsum3
=
vmulq_f32
(
vsum3
,
vinf
);
vst1q_f32
(
dout_ptr0
,
vsum0
);
vst1q_f32
(
dout_ptr1
,
vsum1
);
vst1q_f32
(
dout_ptr2
,
vsum2
);
vst1q_f32
(
dout_ptr3
,
vsum3
);
i
+=
4
;
}
for
(;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner4_axis4
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
2
;
int
remain
=
compute_size
%
4
;
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
4
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get max axis_size == 4
const
float
*
din_ptr
=
din
+
real_index
;
const
float
*
din_ptr1
=
din_ptr
+
inner_num
;
const
float
*
din_ptr2
=
din_ptr1
+
inner_num
;
const
float
*
din_ptr3
=
din_ptr2
+
inner_num
;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr1
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr2
);
float32x4_t
vdata3
=
vld1q_f32
(
din_ptr3
);
float
*
dout_ptr0
=
dout
+
real_index
;
float
*
dout_ptr1
=
dout_ptr0
+
inner_num
;
float32x4_t
vmax1
=
vmaxq_f32
(
vdata0
,
vdata1
);
float32x4_t
vmax2
=
vmaxq_f32
(
vdata2
,
vdata3
);
float
*
dout_ptr2
=
dout_ptr1
+
inner_num
;
float
*
dout_ptr3
=
dout_ptr2
+
inner_num
;
float32x4_t
vmax
=
vmaxq_f32
(
vmax1
,
vmax2
);
// sub, exp and sum
float32x4_t
vsum0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
float32x4_t
vsum1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax
));
float32x4_t
vsum3
=
exp_ps
(
vsubq_f32
(
vdata3
,
vmax
));
float32x4_t
vsum_1
=
vaddq_f32
(
vsum0
,
vsum1
);
float32x4_t
vsum_2
=
vaddq_f32
(
vsum2
,
vsum3
);
float32x4_t
vsum
=
vaddq_f32
(
vsum_1
,
vsum_2
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
vsum0
=
vmulq_f32
(
vsum0
,
vinf
);
vsum1
=
vmulq_f32
(
vsum1
,
vinf
);
vsum2
=
vmulq_f32
(
vsum2
,
vinf
);
vsum3
=
vmulq_f32
(
vsum3
,
vinf
);
vst1q_f32
(
dout_ptr0
,
vsum0
);
vst1q_f32
(
dout_ptr1
,
vsum1
);
vst1q_f32
(
dout_ptr2
,
vsum2
);
vst1q_f32
(
dout_ptr3
,
vsum3
);
}
int
i
=
cmp_cnt
*
8
;
for
(;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner8
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
3
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
8
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
const
float
*
din_ptr
=
din
+
real_index
;
float32x4_t
vmax
=
vld1q_f32
(
din_ptr
);
float32x4_t
vmax2
=
vld1q_f32
(
din_ptr
+
4
);
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
din_ptr
+=
inner_num
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr
+
4
);
vmax
=
vmaxq_f32
(
vmax
,
vdata
);
vmax2
=
vmaxq_f32
(
vmax2
,
vdata2
);
}
// sub, exp and sum
din_ptr
=
din
+
real_index
;
float
*
dout_ptr
=
dout
+
real_index
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
vsum
=
exp_ps
(
vsubq_f32
(
vdata
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax2
));
din_ptr
+=
inner_num
;
vst1q_f32
(
dout_ptr
,
vsum
);
vst1q_f32
(
dout_ptr
+
4
,
vsum2
);
dout_ptr
+=
inner_num
;
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr
+
4
);
vdata0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
vdata1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax2
));
din_ptr
+=
inner_num
;
vsum
=
vaddq_f32
(
vsum
,
vdata0
);
vsum2
=
vaddq_f32
(
vsum2
,
vdata1
);
vst1q_f32
(
dout_ptr
,
vdata0
);
vst1q_f32
(
dout_ptr
+
4
,
vdata1
);
dout_ptr
+=
inner_num
;
}
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
float32x4_t
vinf2
=
div_ps
(
vone
,
vsum2
);
dout_ptr
=
dout
+
real_index
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
float32x4_t
vdata0
=
vld1q_f32
(
dout_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
dout_ptr
+
4
);
vdata0
=
vmulq_f32
(
vdata0
,
vinf
);
vdata1
=
vmulq_f32
(
vdata1
,
vinf2
);
vst1q_f32
(
dout_ptr
,
vdata0
);
vst1q_f32
(
dout_ptr
+
4
,
vdata1
);
dout_ptr
+=
inner_num
;
}
}
for
(
int
i
=
cmp_cnt
*
8
;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner4
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
2
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
4
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// float max_data = din[real_index];
const
float
*
din_ptr
=
din
+
real_index
;
float32x4_t
vmax
=
vld1q_f32
(
din_ptr
);
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
din_ptr
+=
inner_num
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
vmax
=
vmaxq_f32
(
vmax
,
vdata
);
}
// sub, exp and sum
din_ptr
=
din
+
real_index
;
float
*
dout_ptr
=
dout
+
real_index
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
float32x4_t
vsum
=
exp_ps
(
vsubq_f32
(
vdata
,
vmax
));
din_ptr
+=
inner_num
;
vst1q_f32
(
dout_ptr
,
vsum
);
dout_ptr
+=
inner_num
;
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
// real_index += inner_num;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
vdata0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
din_ptr
+=
inner_num
;
vsum
=
vaddq_f32
(
vsum
,
vdata0
);
vst1q_f32
(
dout_ptr
,
vdata0
);
dout_ptr
+=
inner_num
;
}
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
dout_ptr
=
dout
+
real_index
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
float32x4_t
vdata0
=
vld1q_f32
(
dout_ptr
);
vdata0
=
vmulq_f32
(
vdata0
,
vinf
);
vst1q_f32
(
dout_ptr
,
vdata0
);
dout_ptr
+=
inner_num
;
}
}
for
(
int
i
=
cmp_cnt
*
4
;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template <>
void softmax_inner1_large_axis<float>(const float* din, float* dout,
                                      const int outer_size,
                                      const int axis_size) {
#pragma omp parallel for
  for (int i = 0; i < outer_size; ++i) {
    const float* din_ptr = din + i * axis_size;
    float* dout_ptr = dout + i * axis_size;
    const float* din_max_ptr = din_ptr;
    int nn = axis_size >> 2;
    // get max
    float32x4_t vmax = vld1q_f32(din_max_ptr);
    din_max_ptr += 4;
    int j = 1;
    for (; j < nn; ++j) {
      vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr));
      din_max_ptr += 4;
    }
    float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
    float max_data =
        std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
    for (j = 4 * j; j < axis_size; ++j) {
      max_data = std::max(max_data, din_max_ptr[0]);
      din_max_ptr++;
    }
    // sub, exp and sum
    const float* din_sum_ptr = din_ptr;
    float* dout_sum_ptr = dout_ptr;
    vmax = vdupq_n_f32(max_data);
    float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
    float32x4_t vsum = vsub_exp;
    vst1q_f32(dout_sum_ptr, vsub_exp);
    din_sum_ptr += 4;
    dout_sum_ptr += 4;
    j = 1;
    for (; j < nn; ++j) {
      vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
      vst1q_f32(dout_sum_ptr, vsub_exp);
      vsum = vaddq_f32(vsum, vsub_exp);
      din_sum_ptr += 4;
      dout_sum_ptr += 4;
    }
    float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
    float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
    for (j = 4 * j; j < axis_size; ++j) {
      dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
      sum_data += dout_sum_ptr[0];
      din_sum_ptr++;
      dout_sum_ptr++;
    }
    float sum_inv = 1.f / sum_data;
    float* dout_res_ptr = dout_ptr;
    float32x4_t vinv = vdupq_n_f32(sum_inv);
    // get softmax result
    j = 0;
    for (; j < nn; ++j) {
      float32x4_t vout = vld1q_f32(dout_res_ptr);
      float32x4_t vres = vmulq_f32(vout, vinv);
      vst1q_f32(dout_res_ptr, vres);
      dout_res_ptr += 4;
    }
    for (j = nn * 4; j < axis_size; ++j) {
      dout_ptr[j] *= sum_inv;
    }
  }
}
template <>
void softmax_inner1_small_axis<float>(const float* din, float* dout,
                                      const int outer_size,
                                      const int axis_size) {
#pragma omp parallel for
  for (int i = 0; i < outer_size; ++i) {
    const float* din_ptr = din + i * axis_size;
    float* dout_ptr = dout + i * axis_size;
    // get max
    float max_data = din_ptr[0];
    for (int j = 1; j < axis_size; ++j) {
      max_data = std::max(max_data, din_ptr[j]);
    }
    // sub, exp and sum
    float sum_data = 0.f;
    for (int j = 0; j < axis_size; ++j) {
      dout_ptr[j] = expf(din_ptr[j] - max_data);
      sum_data += dout_ptr[j];
    }
    float sum_inv = 1.f / sum_data;
    for (int j = 0; j < axis_size; ++j) {
      dout_ptr[j] *= sum_inv;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
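For reference, every kernel in this file computes the same numerically stable softmax: subtract the per-axis max, exponentiate, then normalize by the sum. Below is a minimal scalar sketch of that recipe, useful as a correctness baseline for the NEON paths above; this helper is illustrative and was not part of the deleted file.

#include <algorithm>
#include <cmath>

// Scalar reference softmax over one contiguous row of axis_size elements.
void softmax_reference(const float* din, float* dout, int axis_size) {
  // get max (subtracted later for numerical stability)
  float max_data = din[0];
  for (int j = 1; j < axis_size; ++j) max_data = std::max(max_data, din[j]);
  // sub, exp and sum
  float sum_data = 0.f;
  for (int j = 0; j < axis_size; ++j) {
    dout[j] = std::exp(din[j] - max_data);
    sum_data += dout[j];
  }
  // normalize by the sum
  float sum_inv = 1.f / sum_data;
  for (int j = 0; j < axis_size; ++j) dout[j] *= sum_inv;
}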
paddle/fluid/lite/arm/math/softmax.h
deleted
100644 → 0
view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
                   const int inner_num, const int outer_num);

template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
                          const int inner_num, const int outer_num);

template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
                          const int inner_num, const int outer_num);

template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
                    const int inner_num, const int outer_num);

template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
                    const int inner_num, const int outer_num);

template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
                               const int axis_size);

template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
                               const int axis_size);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
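A caller would typically pick among these specializations based on inner_num and axis_size. Here is a hedged sketch of such a dispatch for the inner_num == 1 case; the cutoff constant is purely illustrative, since the real selection logic lived in the softmax kernel, not in this header.

// Hypothetical dispatcher: chooses the small- or large-axis kernel when the
// softmax axis is the innermost dimension (inner_num == 1).
void softmax_inner1(const float* din, float* dout, int outer_size,
                    int axis_size) {
  constexpr int kSmallAxisCutoff = 4;  // illustrative threshold, not from the source
  if (axis_size > kSmallAxisCutoff) {
    paddle::lite::arm::math::softmax_inner1_large_axis<float>(
        din, dout, outer_size, axis_size);
  } else {
    paddle::lite::arm::math::softmax_inner1_small_axis<float>(
        din, dout, outer_size, axis_size);
  }
}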
paddle/fluid/lite/arm/math/split.cc
deleted
100644 → 0
view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/split.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void split_cpy<float>(const float* din, float* dout, int num) {
  int cnt = num >> 4;
  int remain = num % 16;
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    vst1q_f32(dout_ptr, din0);
    vst1q_f32(dout_ptr + 4, din1);
    vst1q_f32(dout_ptr + 8, din2);
    vst1q_f32(dout_ptr + 12, din3);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr;
      dout_ptr++;
      din_ptr++;
    }
  }
}

template <>
void split<float>(const float* din, const std::vector<lite::Tensor*>& dout,
                  const int axis, const std::vector<int>& in_strides) {
  int input_offset = 0;
  for (auto out : dout) {
    auto out_dim = out->dims();
    std::vector<int> out_strides(out_dim.size());
    out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
    for (int i = out_dim.size() - 2; i >= 0; --i) {
      out_strides[i] = out_strides[i + 1] * out_dim[i];
    }
    float* out_data = out->mutable_data<float>();
    int before = out_strides[0] / out_strides[axis];
    int in_after = in_strides[axis];
    int out_after = out_strides[axis];
    for (int i = 0; i < before; ++i) {
      split_cpy(din + input_offset + i * in_after, out_data + i * out_after,
                out_after);
    }
    input_offset += out_strides[axis];
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
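Note the stride convention used by split<float> above: out_strides[i] is the product of dims[i..rank-1], so out_strides[0] is the element count of the whole output tensor and out_strides[axis] is the size of one slice from the split axis inward. A standalone sketch of that computation follows; the helper name is hypothetical and is shown only to make the convention concrete.

#include <vector>

// For dims = {2, 3, 4} this returns {24, 12, 4}: strides[i] counts the
// elements covered by dimensions i..rank-1, matching split<float> above.
std::vector<int> suffix_products(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size());
  strides[dims.size() - 1] = dims[dims.size() - 1];
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i];
  }
  return strides;
}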
paddle/fluid/lite/arm/math/split.h
deleted
100644 → 0
view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void split_cpy(const T* din, T* dout, int num);

template <typename T>
void split(const T* din, const std::vector<lite::Tensor*>& dout,
           const int axis, const std::vector<int>& in_strides);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
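A minimal usage sketch for split_cpy, assuming only this header and two float buffers; the sizes are arbitrary and chosen to exercise both the 16-wide NEON loop and the scalar tail.

#include <vector>
#include "paddle/fluid/lite/arm/math/split.h"

int main() {
  std::vector<float> src(20, 1.f);
  std::vector<float> dst(20, 0.f);
  // Copies 20 floats: the vectorized loop handles the first 16 and the
  // scalar remainder loop handles the last 4.
  paddle::lite::arm::math::split_cpy<float>(src.data(), dst.data(),
                                            static_cast<int>(src.size()));
  return 0;
}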