Commit 3a631fbb
Authored Jul 01, 2019 by Chunwei

prune

Parent 0f9e7057

Showing 16 changed files with 0 additions and 4995 deletions (+0 -4995)
paddle/fluid/lite/arm/math/activation.cc     +0  -520
paddle/fluid/lite/arm/math/activation.h      +0  -50
paddle/fluid/lite/arm/math/concat.cc         +0  -59
paddle/fluid/lite/arm/math/concat.h          +0  -34
paddle/fluid/lite/arm/math/dropout.cc        +0  -93
paddle/fluid/lite/arm/math/dropout.h         +0  -32
paddle/fluid/lite/arm/math/elementwise.cc    +0  -261
paddle/fluid/lite/arm/math/elementwise.h     +0  -39
paddle/fluid/lite/arm/math/pooling.cc        +0  -2859
paddle/fluid/lite/arm/math/pooling.h         +0  -73
paddle/fluid/lite/arm/math/scale.cc          +0  -169
paddle/fluid/lite/arm/math/scale.h           +0  -36
paddle/fluid/lite/arm/math/softmax.cc        +0  -601
paddle/fluid/lite/arm/math/softmax.h         +0  -52
paddle/fluid/lite/arm/math/split.cc          +0  -82
paddle/fluid/lite/arm/math/split.h           +0  -35
paddle/fluid/lite/arm/math/activation.cc    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/activation.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void act_relu<float>(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt = nums_per_thread >> 4;
  int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    int cnt = neon_loop_cnt;
#ifdef __aarch64__
    for (int num = 0; num < neon_loop_cnt; ++num) {
      float32x4_t vr0 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr1 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr2 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr3 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      vr0 = vmaxq_f32(vr0, vzero);
      vr1 = vmaxq_f32(vr1, vzero);
      vr2 = vmaxq_f32(vr2, vzero);
      vr3 = vmaxq_f32(vr3, vzero);
      vst1q_f32(ptr_out_thread, vr0);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vr1);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vr2);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vr3);
      ptr_out_thread += 4;
    }
#else
    if (cnt > 0) {
      asm volatile(
          "1:                             @ loop header\n"
          "vld1.32  {d0-d3}, [%[din]]!    @ load din 0\n"
          "vld1.32  {d4-d7}, [%[din]]!    @ load din 0\n"
          "vmax.f32 q8, q0, %q[vzero]     @ relu\n"
          "vmax.f32 q9, q1, %q[vzero]     @ relu\n"
          "vmax.f32 q10, q2, %q[vzero]    @ relu\n"
          "vmax.f32 q11, q3, %q[vzero]    @ relu\n"
          "vst1.32  {d16-d19}, [%[dout]]! @ store result, add pointer\n"
          "vst1.32  {d20-d23}, [%[dout]]! @ store result, add pointer\n"
          "subs %[cnt], #1                @ loop count minus 1\n"
          "bne    1b                      @ jump to main loop start point\n"
          : [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
            [cnt] "+r"(cnt)
          : [vzero] "w"(vzero)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
    }
#endif
    for (int j = 0; j < neon_loop_remain; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f;
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* out_ptr_remain = dout + threads * nums_per_thread;
  const float* in_ptr_remain = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f;
    in_ptr_remain++;
    out_ptr_remain++;
  }
}
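The partitioning above is worth pinning down with numbers: each of `threads` workers gets `nums_per_thread = size / threads` elements, the NEON path eats them 16 at a time, and two scalar tails clean up the per-worker leftover plus the global `size % threads` leftover. The following is a minimal scalar sketch of the same blocking scheme (no NEON, no OpenMP); `act_relu_ref` and `main` are illustrative names, not code from this repository, and it is meant only as a reference when checking the intrinsics path:

#include <algorithm>
#include <cstdio>

// Scalar reference for act_relu's partitioning: same per-thread split and
// tail handling, with std::max(x, 0) in place of vmaxq_f32.
void act_relu_ref(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;           // elements per worker
  int remain = size - threads * nums_per_thread;  // global tail (size % threads)
  for (int i = 0; i < threads; ++i) {
    const float* in = din + i * nums_per_thread;
    float* out = dout + i * nums_per_thread;
    for (int b = 0; b < nums_per_thread; ++b) {   // covers blocks + per-worker tail
      out[b] = std::max(in[b], 0.f);
    }
  }
  for (int j = 0; j < remain; ++j) {
    int idx = threads * nums_per_thread + j;
    dout[idx] = std::max(din[idx], 0.f);
  }
}

int main() {
  // size = 100, threads = 3: 33 per worker (two 16-wide blocks + tail of 1),
  // plus one global tail element handled after the parallel region.
  float in[100], out[100];
  for (int i = 0; i < 100; ++i) in[i] = (i % 2 ? 1.f : -1.f) * i;
  act_relu_ref(in, out, 100, 3);
  std::printf("%f %f\n", out[1], out[2]);  // 1.000000 0.000000
  return 0;
}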
template <>
void act_relu_neg<float>(const float* din, float* dout, int size,
                         const float negative_slope, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt = nums_per_thread >> 4;
  int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
  float32x4_t vzero = vdupq_n_f32(0.f);
  float32x4_t valpha = vdupq_n_f32(negative_slope);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    int cnt = neon_loop_cnt;
#ifdef __aarch64__
    for (int num = 0; num < neon_loop_cnt; ++num) {
      float32x4_t vr0 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr1 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr2 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr3 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      uint32x4_t vm0 = vcgeq_f32(vr0, vzero);
      uint32x4_t vm1 = vcgeq_f32(vr1, vzero);
      uint32x4_t vm2 = vcgeq_f32(vr2, vzero);
      uint32x4_t vm3 = vcgeq_f32(vr3, vzero);
      float32x4_t vn0 = vmulq_f32(vr0, valpha);
      float32x4_t vn1 = vmulq_f32(vr1, valpha);
      float32x4_t vn2 = vmulq_f32(vr2, valpha);
      float32x4_t vn3 = vmulq_f32(vr3, valpha);
      float32x4_t vo0 = vbslq_f32(vm0, vr0, vn0);
      float32x4_t vo1 = vbslq_f32(vm1, vr1, vn1);
      float32x4_t vo2 = vbslq_f32(vm2, vr2, vn2);
      float32x4_t vo3 = vbslq_f32(vm3, vr3, vn3);
      vst1q_f32(ptr_out_thread, vo0);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo1);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo2);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo3);
      ptr_out_thread += 4;
    }
#else
    if (cnt > 0) {
      asm volatile(
          "1:                             @ loop header\n"
          "vld1.32  {d0-d3}, [%[din]]!    @ load din 0\n"
          "vld1.32  {d4-d7}, [%[din]]!    @ load din 0\n"
          "vcge.f32 q8, q0, %q[vzero]     @ get mask\n"
          "vcge.f32 q9, q1, %q[vzero]     @ get mask\n"
          "vcge.f32 q10, q2, %q[vzero]    @ get mask\n"
          "vcge.f32 q11, q3, %q[vzero]    @ get mask\n"
          "vmul.f32 q4, q0, %q[valpha]    @ get neg data\n"
          "vmul.f32 q5, q1, %q[valpha]    @ get neg data\n"
          "vmul.f32 q6, q2, %q[valpha]    @ get neg data\n"
          "vmul.f32 q7, q3, %q[valpha]    @ get neg data\n"
          "vbit q4, q0, q8                @ bitsel, insert q0 to q4, if q8 is 1\n"
          "vbit q5, q1, q9                @ bitsel, insert q1 to q5, if q9 is 1\n"
          "vbit q6, q2, q10               @ bitsel, insert q2 to q6, if q10 is 1\n"
          "vbit q7, q3, q11               @ bitsel, insert q3 to q7, if q11 is 1\n"
          "vst1.32  {d8-d11}, [%[dout]]!  @ store result, add pointer\n"
          "vst1.32  {d12-d15}, [%[dout]]! @ store result, add pointer\n"
          "subs %[cnt], #1                @ loop count minus 1\n"
          "bne    1b                      @ jump to main loop start point\n"
          : [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
            [cnt] "+r"(cnt)
          : [vzero] "w"(vzero), [valpha] "w"(valpha)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11");
    }
#endif
    for (int j = 0; j < neon_loop_remain; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] > 0.f
                              ? ptr_in_thread[0]
                              : ptr_in_thread[0] * negative_slope;
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* out_ptr_remain = dout + threads * nums_per_thread;
  const float* in_ptr_remain = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    out_ptr_remain[0] = in_ptr_remain[0] > 0.f
                            ? in_ptr_remain[0]
                            : in_ptr_remain[0] * negative_slope;
    in_ptr_remain++;
    out_ptr_remain++;
  }
}
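The branch-free select idiom here is the core trick: vcgeq_f32 yields an all-ones/all-zeros lane mask, and vbslq_f32 (or the vbit instruction on armv7) picks x where the mask is set and negative_slope * x elsewhere. A minimal scalar sketch of the same idiom, assuming nothing beyond the standard library (leaky_relu_select is an illustrative name):

#include <cstdint>
#include <cstring>
#include <cstdio>

// Branch-free leaky ReLU per lane, mirroring vcgeq_f32 + vbslq_f32:
// mask is all-ones where x >= 0; the select keeps x there, a*x elsewhere.
float leaky_relu_select(float x, float alpha) {
  uint32_t mask = (x >= 0.f) ? 0xFFFFFFFFu : 0u;
  float neg = x * alpha;
  uint32_t xb, nb;
  std::memcpy(&xb, &x, 4);
  std::memcpy(&nb, &neg, 4);
  uint32_t r = (mask & xb) | (~mask & nb);
  float out;
  std::memcpy(&out, &r, 4);
  return out;
}

int main() {
  std::printf("%f %f\n", leaky_relu_select(2.f, 0.1f),
              leaky_relu_select(-2.f, 0.1f));  // 2.000000 -0.200000
  return 0;
}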
template <>
void act_clipped_relu<float>(const float* din, float* dout, int size,
                             const float coef, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt = nums_per_thread >> 4;
  int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
  float32x4_t vzero = vdupq_n_f32(0.f);
  float32x4_t vclip = vdupq_n_f32(coef);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    int cnt = neon_loop_cnt;
#ifdef __aarch64__
    for (int num = 0; num < neon_loop_cnt; ++num) {
      float32x4_t vr0 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr1 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr2 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vr3 = vld1q_f32(ptr_in_thread);
      ptr_in_thread += 4;
      float32x4_t vt0 = vmaxq_f32(vr0, vzero);
      float32x4_t vt1 = vmaxq_f32(vr1, vzero);
      float32x4_t vt2 = vmaxq_f32(vr2, vzero);
      float32x4_t vt3 = vmaxq_f32(vr3, vzero);
      float32x4_t vo0 = vminq_f32(vt0, vclip);
      float32x4_t vo1 = vminq_f32(vt1, vclip);
      float32x4_t vo2 = vminq_f32(vt2, vclip);
      float32x4_t vo3 = vminq_f32(vt3, vclip);
      vst1q_f32(ptr_out_thread, vo0);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo1);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo2);
      ptr_out_thread += 4;
      vst1q_f32(ptr_out_thread, vo3);
      ptr_out_thread += 4;
    }
#else
    if (cnt > 0) {
      asm volatile(
          "1:                             @ loop header\n"
          "vld1.32  {d0-d3}, [%[din]]!    @ load din 0\n"
          "vld1.32  {d4-d7}, [%[din]]!    @ load din 0\n"
          "vmax.f32 q8, q0, %q[vzero]     @ relu\n"
          "vmax.f32 q9, q1, %q[vzero]     @ relu\n"
          "vmax.f32 q10, q2, %q[vzero]    @ relu\n"
          "vmax.f32 q11, q3, %q[vzero]    @ relu\n"
          "vmin.f32 q4, q8, %q[vclip]     @ clip relu\n"
          "vmin.f32 q5, q9, %q[vclip]     @ clip relu\n"
          "vmin.f32 q6, q10, %q[vclip]    @ clip relu\n"
          "vmin.f32 q7, q11, %q[vclip]    @ clip relu\n"
          "vst1.32  {d8-d11}, [%[dout]]!  @ store result, add pointer\n"
          "vst1.32  {d12-d15}, [%[dout]]! @ store result, add pointer\n"
          "subs %[cnt], #1                @ loop count minus 1\n"
          "bne    1b                      @ jump to main loop start point\n"
          : [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
            [cnt] "+r"(cnt)
          : [vzero] "w"(vzero), [vclip] "w"(vclip)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11");
    }
#endif
    for (int j = 0; j < neon_loop_remain; ++j) {
      ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f;
      ptr_out_thread[0] = ptr_out_thread[0] < coef ? ptr_out_thread[0] : coef;
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* out_ptr_remain = dout + threads * nums_per_thread;
  const float* in_ptr_remain = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f;
    out_ptr_remain[0] = out_ptr_remain[0] < coef ? out_ptr_remain[0] : coef;
    in_ptr_remain++;
    out_ptr_remain++;
  }
}
template <>
void act_prelu<float>(const float* din, float* dout, int outer_size,
                      int channel_size, int inner_size, bool channel_shared,
                      float* channel_slope, int threads) {
  int stride_size = inner_size * channel_size;
  int cnt = inner_size >> 4;
  int remain = inner_size & 15;
  float32x4_t vzero = vdupq_n_f32(0.f);
  for (int n = 0; n < outer_size; n++) {
    const float* data_in_batch = din + n * stride_size;
    float* data_out_batch = dout + n * stride_size;
#pragma omp parallel for
    for (int c = 0; c < channel_size; c++) {
      const float* data_in_c = data_in_batch + c * inner_size;
      float* data_out_c = data_out_batch + c * inner_size;
      float slope = channel_shared ? channel_slope[0] : channel_slope[c];
      float32x4_t vslope = vdupq_n_f32(slope);
#ifdef __aarch64__
      for (int i = 0; i < cnt; ++i) {
        float32x4_t vr0 = vld1q_f32(data_in_c);
        float32x4_t vr1 = vld1q_f32(data_in_c + 4);
        float32x4_t vr2 = vld1q_f32(data_in_c + 8);
        float32x4_t vr3 = vld1q_f32(data_in_c + 12);
        uint32x4_t vm0 = vcltq_f32(vr0, vzero);    // vr0 <= vzero
        uint32x4_t vm1 = vcltq_f32(vr1, vzero);    // vr0 <= vzero
        uint32x4_t vm2 = vcltq_f32(vr2, vzero);    // vr0 <= vzero
        uint32x4_t vm3 = vcltq_f32(vr3, vzero);    // vr0 <= vzero
        float32x4_t vo0 = vmulq_f32(vr0, vslope);  // vr0 * vslope
        float32x4_t vo1 = vmulq_f32(vr1, vslope);  // vr0 * vslope
        float32x4_t vo2 = vmulq_f32(vr2, vslope);  // vr0 * vslope
        float32x4_t vo3 = vmulq_f32(vr3, vslope);  // vr0 * vslope
        float32x4_t vos0 = vbslq_f32(vm0, vo0, vr0);
        float32x4_t vos1 = vbslq_f32(vm1, vo1, vr1);
        float32x4_t vos2 = vbslq_f32(vm2, vo2, vr2);
        float32x4_t vos3 = vbslq_f32(vm3, vo3, vr3);
        vst1q_f32(data_out_c, vos0);
        vst1q_f32(data_out_c + 4, vos1);
        vst1q_f32(data_out_c + 8, vos2);
        vst1q_f32(data_out_c + 12, vos3);
        data_in_c += 16;
        data_out_c += 16;
      }
#else
      int cnt_loop = cnt;
      if (cnt_loop > 0) {
        asm volatile(
            "vld1.32  {d0-d3}, [%[ptr_in]]!  @ load input to q0, q1\n"
            "pld [%[ptr_in]]                 @ preload\n"
            "pld [%[ptr_in], #64]            @ preload\n"
            "pld [%[ptr_in], #128]           @ preload\n"
            "pld [%[ptr_in], #192]           @ preload\n"
            "1:                              @main loop\n"
            "vld1.32  {d4-d7}, [%[ptr_in]]!  @ load input to q2, q3\n"
            "vclt.f32 q8, q0, %q[vzero]      @vcle q0 <= vzero\n"
            "vclt.f32 q9, q1, %q[vzero]      @vcle q1 <= vzero\n"
            "vmul.f32 q10, q0, %q[vslope]    @vmul q0 * vslope\n"
            "vmul.f32 q11, q1, %q[vslope]    @vmul q1 * vslope\n"
            "vclt.f32 q12, q2, %q[vzero]     @vcle q2 <= vzero\n"
            "vclt.f32 q13, q3, %q[vzero]     @vcle q3 <= vzero\n"
            "vmul.f32 q14, q2, %q[vslope]    @vmul q2 * vslope\n"
            "vmul.f32 q15, q3, %q[vslope]    @vmul q3 * vslope\n"
            "vbif.32  q10, q0, q8            @vbit q10, q0, q8\n"
            "vbif.32  q11, q1, q9            @vbit q11, q1, q9\n"
            "vbif.32  q14, q2, q12           @vbit q14, q2, q12\n"
            "vbif.32  q15, q3, q13           @vbit q15, q3, q13\n"
            "subs     %[cnt], #1             @subs nn, 1\n"
            "vld1.32  {d0-d3}, [%[ptr_in]]!  @ load input to q0, q1\n"
            "vst1.f32 {d20-d23}, [%[dout]]!  @store data\n"
            "vst1.f32 {d28-d31}, [%[dout]]!  @store data\n"
            "bne      1b                     @bne nn\n"
            "sub  %[ptr_in], #32             @ ptr-32\n"
            : [ptr_in] "+r"(data_in_c), [cnt] "+r"(cnt_loop),
              [dout] "+r"(data_out_c)
            : [vzero] "w"(vzero), [vslope] "w"(vslope)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
      }
#endif  // __aarch64__
      for (int i = remain; i > 0; i--) {
        *(data_out_c++) =
            data_in_c[0] > 0.f ? data_in_c[0] : data_in_c[0] * slope;
        data_in_c++;
      }
    }
  }
}
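act_prelu indexes the input as a flattened [outer_size, channel_size, inner_size] tensor (for NCHW data: N, C, and H*W), with one slope per channel unless channel_shared is set. A scalar sketch of that indexing, under those layout assumptions (prelu_ref is an illustrative name, not code from this repository):

#include <cstdio>

// Scalar sketch of act_prelu's indexing for an NCHW tensor:
// outer_size = N, channel_size = C, inner_size = H * W.
void prelu_ref(const float* din, float* dout, int outer_size, int channel_size,
               int inner_size, bool channel_shared,
               const float* channel_slope) {
  for (int n = 0; n < outer_size; ++n) {
    for (int c = 0; c < channel_size; ++c) {
      float slope = channel_shared ? channel_slope[0] : channel_slope[c];
      const float* in = din + (n * channel_size + c) * inner_size;
      float* out = dout + (n * channel_size + c) * inner_size;
      for (int i = 0; i < inner_size; ++i) {
        out[i] = in[i] > 0.f ? in[i] : in[i] * slope;
      }
    }
  }
}

int main() {
  float slopes[2] = {0.1f, 0.5f};
  float in[2 * 2 * 3], out[2 * 2 * 3];
  for (int i = 0; i < 12; ++i) in[i] = -1.f;
  prelu_ref(in, out, 2, 2, 3, false, slopes);
  std::printf("%f %f\n", out[0], out[3]);  // -0.1 (channel 0), -0.5 (channel 1)
  return 0;
}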
template <>
void act_sigmoid(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    float32x4_t exp_vec = vdupq_n_f32(0.0f);
    float32x4_t recip = vdupq_n_f32(0.0f);
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      exp_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread)));
      exp_vec = vaddq_f32(exp_vec, vdupq_n_f32(1.0f));
      recip = vrecpeq_f32(exp_vec);
      recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip);
      recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip);
      vst1q_f32(ptr_out_thread, recip);
      ptr_out_thread += 4;
      ptr_in_thread += 4;
    }
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] = 1.f / (1 + expf(-ptr_in_thread[0]));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = 1.f / (1 + expf(-ptr_in[0]));
    ptr_in++;
    ptr_out++;
  }
}
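The reciprocal here is not a division: vrecpeq_f32 gives a rough estimate of 1/x, and vrecpsq_f32(x, r) computes (2 - x * r), so each `recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip)` line is one Newton-Raphson step r' = r * (2 - x * r). Two steps take the coarse estimate close to full single precision. A scalar sketch of that refinement (recip_step is an illustrative name):

#include <cstdio>

// One Newton-Raphson step for 1/x, the scalar analogue of
// recip = vmulq_f32(vrecpsq_f32(x, recip), recip).
float recip_step(float x, float r) { return r * (2.f - x * r); }

int main() {
  float x = 3.f;
  float r = 0.3f;        // crude seed, standing in for vrecpeq_f32
  r = recip_step(x, r);  // 0.33
  r = recip_step(x, r);  // 0.3333 -> converges quadratically toward 1/3
  std::printf("%.7f vs %.7f\n", r, 1.f / x);
  return 0;
}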
// tanh : (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <>
void act_tanh<float>(const float* din, float* dout, int size, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    float32x4_t exp_plus_vec = vdupq_n_f32(0.0f);
    float32x4_t exp_minus_vec = vdupq_n_f32(0.0f);
    float32x4_t exp_sum_vec = vdupq_n_f32(0.0f);
    float32x4_t exp_diff_vec = vdupq_n_f32(0.0f);
    float32x4_t recip = vdupq_n_f32(0.0f);
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      exp_plus_vec = exp_ps(vld1q_f32(ptr_in_thread));
      exp_minus_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread)));
      exp_sum_vec = vaddq_f32(exp_plus_vec, exp_minus_vec);
      exp_diff_vec = vsubq_f32(exp_plus_vec, exp_minus_vec);
      recip = div_ps(exp_diff_vec, exp_sum_vec);
      vst1q_f32(ptr_out_thread, recip);
      ptr_out_thread += 4;
      ptr_in_thread += 4;
    }
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] = (expf(ptr_in_thread[0]) - expf(-ptr_in_thread[0])) /
                          (expf(ptr_in_thread[0]) + expf(-ptr_in_thread[0]));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = (expf(ptr_in[0]) - expf(-ptr_in[0])) /
                 (expf(ptr_in[0]) + expf(-ptr_in[0]));
    ptr_in++;
    ptr_out++;
  }
}
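One caveat with the quotient form above: once expf(|x|) overflows to infinity (around |x| > 88 for float), the expression becomes inf/inf = nan. A clamped single-exponential form stays finite; this is shown only as a reference point, not something the deleted kernel does:

#include <cmath>
#include <cstdio>

// tanh(x) = 1 - 2 / (exp(2x) + 1), with the argument clamped so exp stays
// finite in float. A sketch for comparison against the two-exp quotient.
float tanh_ref(float x) {
  float xc = std::fmin(std::fmax(x, -40.f), 40.f);  // expf(80) is still finite
  return 1.f - 2.f / (std::exp(2.f * xc) + 1.f);
}

int main() {
  std::printf("%f %f %f\n", tanh_ref(-100.f), tanh_ref(0.5f), tanh_ref(100.f));
  // -1.000000 0.462117 1.000000
  return 0;
}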
// swish: x /(1 + exp(-(b * x)))
template <>
void act_swish<float>(const float* din, float* dout, int size,
                      const float coef, int threads) {
  int nums_per_thread = size / threads;
  int remain = size - threads * nums_per_thread;
  int neon_loop_cnt_dim4 = nums_per_thread >> 2;
  int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
  const float beta = coef;
  float32x4_t vbeta = vdupq_n_f32(beta);
  float32x4_t vone = vdupq_n_f32(1.f);
#pragma omp parallel for
  for (int i = 0; i < threads; ++i) {
    const float* ptr_in_thread = din + i * nums_per_thread;
    float* ptr_out_thread = dout + i * nums_per_thread;
    for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
      float32x4_t va = vld1q_f32(ptr_in_thread);             // x
      float32x4_t vb = vnegq_f32(vld1q_f32(ptr_in_thread));  // -x
      float32x4_t vsum = vmulq_f32(vb, vbeta);
      vsum = exp_ps(vsum);
      float32x4_t vc = vaddq_f32(vone, vsum);
      float32x4_t vrst = div_ps(va, vc);
      vst1q_f32(ptr_out_thread, vrst);
      ptr_out_thread += 4;
      ptr_in_thread += 4;
    }
    for (int j = 0; j < neon_loop_remain_dim4; ++j) {
      ptr_out_thread[0] =
          ptr_in_thread[0] / (1.0 + expf(-ptr_in_thread[0] * beta));
      ptr_in_thread++;
      ptr_out_thread++;
    }
  }
  float* ptr_out = dout + threads * nums_per_thread;
  const float* ptr_in = din + threads * nums_per_thread;
  for (int j = 0; j < remain; ++j) {
    ptr_out[0] = ptr_in[0] / (1.0 + expf(-ptr_in[0] * beta));
    ptr_in++;
    ptr_out++;
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/activation.h    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void act_relu(const T* din, T* dout, int size, int threads);

template <typename T>
void act_relu_neg(const T* din, T* dout, int size, const float negative_slope,
                  int threads);

template <typename T>
void act_clipped_relu(const T* din, T* dout, int size, const float coef,
                      int threads);

template <typename T>
void act_prelu(const T* din, T* dout, int outer_size, int channel_size,
               int inner_size, bool channel_shared, float* channel_slope,
               int threads);

template <typename T>
void act_sigmoid(const T* din, T* dout, int size, int threads);

template <typename T>
void act_tanh(const T* din, T* dout, int size, int threads);

template <typename T>
void act_swish(const T* din, T* dout, int size, const float coef, int threads);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
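These declarations are specialized for float in activation.cc; a kernel picks one based on the op's activation parameters. The following self-contained sketch shows that dispatch shape with tiny scalar stand-ins (run_relu and the stand-in bodies are illustrative, not the repository's float specializations):

#include <cstdio>
#include <vector>

// Stand-ins so this sketch compiles on its own; in the real tree these come
// from activation.cc's NEON float specializations.
template <typename T>
void act_relu(const T* din, T* dout, int size, int threads) {
  for (int i = 0; i < size; ++i) dout[i] = din[i] > T(0) ? din[i] : T(0);
}
template <typename T>
void act_relu_neg(const T* din, T* dout, int size, const float slope,
                  int threads) {
  for (int i = 0; i < size; ++i)
    dout[i] = din[i] > T(0) ? din[i] : T(din[i] * slope);
}

// A call site would typically dispatch on the leaky-ReLU slope:
void run_relu(const float* x, float* y, int n, float negative_slope,
              int threads) {
  if (negative_slope == 0.f) {
    act_relu<float>(x, y, n, threads);
  } else {
    act_relu_neg<float>(x, y, n, negative_slope, threads);
  }
}

int main() {
  std::vector<float> x = {-2.f, 3.f}, y(2);
  run_relu(x.data(), y.data(), 2, 0.1f, 1);
  std::printf("%f %f\n", y[0], y[1]);  // -0.200000 3.000000
  return 0;
}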
paddle/fluid/lite/arm/math/concat.cc    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/concat.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

void concat_func(const std::vector<lite::Tensor*>& input, const int axis,
                 lite::Tensor* output) {
  size_t num = input.size();
  int rows = 1;
  auto dim_0 = input[0]->dims();
  for (int i = 0; i < axis; ++i) {
    rows *= dim_0[i];
  }
  int out_rows = rows, out_cols = 0;
  std::vector<int64_t> input_cols(input.size());
  for (int i = 0; i < num; ++i) {
    int t_cols = input[i]->numel() / rows;
    out_cols += t_cols;
    input_cols[i] = t_cols;
  }
  // computation
  for (int k = 0; k < out_rows; ++k) {
    float* dst_ptr = output->mutable_data<float>() + k * out_cols;
    int col_idx = 0;
    for (int j = 0; j < num; ++j) {
      int col_len = input_cols[j];
      const float* src_prt = input[j]->data<float>() + k * col_len;
      std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len);
      col_idx += col_len;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
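concat_func flattens every tensor to [rows, cols], where rows is the product of the dimensions before `axis`; concatenating shapes [2, 3, 4] and [2, 5, 4] along axis 1 gives rows = 2, per-input cols 12 and 20, and each output row is assembled by one memcpy per input. A scalar sketch of that row/column logic with raw buffers in place of lite::Tensor (concat_rows is an illustrative name):

#include <cstdio>
#include <cstring>
#include <vector>

// Each input j contributes cols[j] floats per row; output rows are the
// inputs' rows laid side by side.
void concat_rows(const std::vector<const float*>& in,
                 const std::vector<int>& cols, int rows, float* out) {
  int out_cols = 0;
  for (int c : cols) out_cols += c;
  for (int k = 0; k < rows; ++k) {
    float* dst = out + k * out_cols;
    int col_idx = 0;
    for (size_t j = 0; j < in.size(); ++j) {
      std::memcpy(dst + col_idx, in[j] + k * cols[j], sizeof(float) * cols[j]);
      col_idx += cols[j];
    }
  }
}

int main() {
  // Concatenating shapes [2, 3] and [2, 5] along axis 1: rows = 2.
  float a[6] = {0, 1, 2, 3, 4, 5};
  float b[10] = {9, 9, 9, 9, 9, 9, 9, 9, 9, 9};
  float out[16];
  concat_rows({a, b}, {3, 5}, 2, out);
  std::printf("%g %g %g %g\n", out[0], out[3], out[8], out[11]);  // 0 9 3 9
  return 0;
}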
paddle/fluid/lite/arm/math/concat.h    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

void concat_func(const std::vector<lite::Tensor*>& input, const int axis,
                 lite::Tensor* output);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/dropout.cc    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/dropout.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void dropout_down<float>(const float* din, float* dout, int num, float prob) {
  const float scale = 1.0f - prob;
  int cnt = num >> 4;
  int remain = num % 16;
  float32x4_t vscale = vdupq_n_f32(scale);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    float32x4_t vmul0 = vmulq_f32(din0, vscale);
    float32x4_t vmul1 = vmulq_f32(din1, vscale);
    float32x4_t vmul2 = vmulq_f32(din2, vscale);
    float32x4_t vmul3 = vmulq_f32(din3, vscale);
    vst1q_f32(dout_ptr, vmul0);
    vst1q_f32(dout_ptr + 4, vmul1);
    vst1q_f32(dout_ptr + 8, vmul2);
    vst1q_f32(dout_ptr + 12, vmul3);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr * scale;
      dout_ptr++;
      din_ptr++;
    }
  }
}

template <>
void dropout_up<float>(const float* din, float* dout, int num) {
  int cnt = num >> 4;
  int remain = num % 16;
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    vst1q_f32(dout_ptr, din0);
    vst1q_f32(dout_ptr + 4, din1);
    vst1q_f32(dout_ptr + 8, din2);
    vst1q_f32(dout_ptr + 12, din3);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr;
      dout_ptr++;
      din_ptr++;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
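Note what these two entry points appear to implement: dropout_down is the inference-time path that scales every element by (1 - prob) with no masking (the "downscale in inference" dropout convention), while dropout_up is a plain copy, used when the scaling already happened at training time. A tiny self-contained illustration of the downscale convention (an assumption inferred from the code above, not a statement from the repository):

#include <cstdio>

// Inference-time dropout, downscale convention: y = x * (1 - prob).
int main() {
  const float prob = 0.25f, scale = 1.0f - prob;
  float x[4] = {4.f, 8.f, -4.f, 0.f}, y[4];
  for (int i = 0; i < 4; ++i) y[i] = x[i] * scale;
  std::printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]);  // 3 6 -3 0
  return 0;
}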
paddle/fluid/lite/arm/math/dropout.h    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void dropout_down(const T* din, T* dout, int num, float prob);

template <typename T>
void dropout_up(const T* din, T* dout, int num);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/elementwise.cc    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/elementwise.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
                            int num) {
  int cnt = num >> 4;
  int remain = num % 16;
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* dinx_ptr = dinx + (i << 4);
    const float* diny_ptr = diny + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t dinx0 = vld1q_f32(dinx_ptr);
    float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
    float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
    float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
    float32x4_t diny0 = vld1q_f32(diny_ptr);
    float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
    float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
    float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
    dinx0 = vaddq_f32(dinx0, diny0);
    dinx1 = vaddq_f32(dinx1, diny1);
    dinx2 = vaddq_f32(dinx2, diny2);
    dinx3 = vaddq_f32(dinx3, diny3);
    vst1q_f32(dout_ptr, dinx0);
    vst1q_f32(dout_ptr + 4, dinx1);
    vst1q_f32(dout_ptr + 8, dinx2);
    vst1q_f32(dout_ptr + 12, dinx3);
  }
  if (remain > 0) {
    const float* dinx_ptr = dinx + (cnt << 4);
    const float* diny_ptr = diny + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *dinx_ptr + *diny_ptr;
      dout_ptr++;
      dinx_ptr++;
      diny_ptr++;
    }
  }
}

template <>
void elementwise_add_relu<float>(const float* dinx, const float* diny,
                                 float* dout, int num) {
  int cnt = num >> 4;
  int remain = num % 16;
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* dinx_ptr = dinx + (i << 4);
    const float* diny_ptr = diny + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t dinx0 = vld1q_f32(dinx_ptr);
    float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
    float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
    float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
    float32x4_t diny0 = vld1q_f32(diny_ptr);
    float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
    float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
    float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
    dinx0 = vaddq_f32(dinx0, diny0);
    dinx1 = vaddq_f32(dinx1, diny1);
    dinx2 = vaddq_f32(dinx2, diny2);
    dinx3 = vaddq_f32(dinx3, diny3);
    // relu
    dinx0 = vmaxq_f32(dinx0, vzero);
    dinx1 = vmaxq_f32(dinx1, vzero);
    dinx2 = vmaxq_f32(dinx2, vzero);
    dinx3 = vmaxq_f32(dinx3, vzero);
    vst1q_f32(dout_ptr, dinx0);
    vst1q_f32(dout_ptr + 4, dinx1);
    vst1q_f32(dout_ptr + 8, dinx2);
    vst1q_f32(dout_ptr + 12, dinx3);
  }
  if (remain > 0) {
    const float* dinx_ptr = dinx + (cnt << 4);
    const float* diny_ptr = diny + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      float tmp = *dinx_ptr + *diny_ptr;
      *dout_ptr = tmp > 0.f ? tmp : 0.f;
      dout_ptr++;
      dinx_ptr++;
      diny_ptr++;
    }
  }
}

template <>
void elementwise_add_broadcast<float>(const float* dinx, const float* diny,
                                      float* dout, int batch, int channels,
                                      int num) {
#pragma omp parallel for collapse(2)
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;
      const float* din_ptr = dinx + offset;
      const float diny_data = diny[j];
      float* dout_ptr = dout + offset;
      int cnt = num >> 4;
      int remain = num % 16;
      float32x4_t rb = vdupq_n_f32(diny_data);
      for (int k = 0; k < cnt; ++k) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        float32x4_t din2 = vld1q_f32(din_ptr + 8);
        float32x4_t din3 = vld1q_f32(din_ptr + 12);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        din2 = vaddq_f32(din2, rb);
        din3 = vaddq_f32(din3, rb);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        vst1q_f32(dout_ptr + 8, din2);
        vst1q_f32(dout_ptr + 12, din3);
        din_ptr += 16;
        dout_ptr += 16;
      }
      if (remain >= 8) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        din_ptr += 8;
        dout_ptr += 8;
        remain -= 8;
      }
      if (remain >= 4) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        din0 = vaddq_f32(din0, rb);
        vst1q_f32(dout_ptr, din0);
        din_ptr += 4;
        dout_ptr += 4;
        remain -= 4;
      }
      if (remain > 0) {
        for (int p = 0; p < remain; p++) {
          *dout_ptr = *din_ptr + diny_data;
          dout_ptr++;
          din_ptr++;
        }
      }
    }
  }
}

template <>
void elementwise_add_relu_broadcast<float>(const float* dinx,
                                           const float* diny, float* dout,
                                           int batch, int channels, int num) {
  float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for collapse(2)
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;
      const float* din_ptr = dinx + offset;
      const float diny_data = diny[j];
      float* dout_ptr = dout + offset;
      int cnt = num >> 4;
      int remain = num % 16;
      float32x4_t rb = vdupq_n_f32(diny_data);
      for (int k = 0; k < cnt; ++k) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        float32x4_t din2 = vld1q_f32(din_ptr + 8);
        float32x4_t din3 = vld1q_f32(din_ptr + 12);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        din2 = vaddq_f32(din2, rb);
        din3 = vaddq_f32(din3, rb);
        // relu
        din0 = vmaxq_f32(din0, vzero);
        din1 = vmaxq_f32(din1, vzero);
        din2 = vmaxq_f32(din2, vzero);
        din3 = vmaxq_f32(din3, vzero);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        vst1q_f32(dout_ptr + 8, din2);
        vst1q_f32(dout_ptr + 12, din3);
        din_ptr += 16;
        dout_ptr += 16;
      }
      if (remain >= 8) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        din0 = vaddq_f32(din0, rb);
        din1 = vaddq_f32(din1, rb);
        // relu
        din0 = vmaxq_f32(din0, vzero);
        din1 = vmaxq_f32(din1, vzero);
        vst1q_f32(dout_ptr, din0);
        vst1q_f32(dout_ptr + 4, din1);
        din_ptr += 8;
        dout_ptr += 8;
        remain -= 8;
      }
      if (remain >= 4) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        din0 = vaddq_f32(din0, rb);
        // relu
        din0 = vmaxq_f32(din0, vzero);
        vst1q_f32(dout_ptr, din0);
        din_ptr += 4;
        dout_ptr += 4;
        remain -= 4;
      }
      if (remain > 0) {
        for (int p = 0; p < remain; p++) {
          float tmp = *din_ptr + diny_data;
          *dout_ptr = tmp > 0.f ? tmp : 0.f;
          dout_ptr++;
          din_ptr++;
        }
      }
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
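The broadcast variants treat x as [batch, channels, num] and y as [channels], adding y[j] across the whole inner slice; with num = H * W this is the usual per-channel bias add over NCHW. A scalar sketch of that indexing (add_broadcast_ref is an illustrative name):

#include <cstdio>

// y holds one value per channel and is broadcast across the `num` inner
// elements of each (batch, channel) slice.
void add_broadcast_ref(const float* x, const float* y, float* out, int batch,
                       int channels, int num) {
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      int offset = (i * channels + j) * num;
      for (int k = 0; k < num; ++k) {
        out[offset + k] = x[offset + k] + y[j];
      }
    }
  }
}

int main() {
  float x[12] = {0}, y[2] = {1.f, 10.f}, out[12];  // batch=2, channels=2, num=3
  add_broadcast_ref(x, y, out, 2, 2, 3);
  std::printf("%g %g %g %g\n", out[0], out[3], out[6], out[9]);  // 1 10 1 10
  return 0;
}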
paddle/fluid/lite/arm/math/elementwise.h    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);

template <typename T>
void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num);

template <typename T>
void elementwise_add_broadcast(const T* dinx, const T* diny, T* dout,
                               int batch, int channels, int num);

template <typename T>
void elementwise_add_relu_broadcast(const T* dinx, const T* diny, T* dout,
                                    int batch, int channels, int num);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
paddle/fluid/lite/arm/math/pooling.cc    deleted    100644 → 0
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/pooling.h"
#include <algorithm>
#include <limits>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

void pooling_basic(const float* din, float* dout, int num, int chout,
                   int hout, int wout, int chin, int hin, int win,
                   const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, bool global_pooling,
                   bool exclusive, bool adaptive, bool ceil_mode,
                   bool use_quantizer, const std::string& pooling_type) {
  // no need to pad input tensor, border is zero pad inside this function
  int kernel_h = ksize[0];
  int kernel_w = ksize[1];
  int stride_h = strides[0];
  int stride_w = strides[1];
  int pad_h = paddings[0];
  int pad_w = paddings[1];
  int size_channel_in = win * hin;
  int size_channel_out = wout * hout;
  if (global_pooling) {
    if (pooling_type == "max") {
      // Pooling_max
      for (int n = 0; n < num; ++n) {
        float* dout_batch = dout + n * chout * size_channel_out;
        const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
        for (int c = 0; c < chout; ++c) {
          const float* din_ch = din_batch + c * size_channel_in;  // in address
          float tmp1 = din_ch[0];
          for (int i = 0; i < size_channel_in; ++i) {
            float tmp2 = din_ch[i];
            tmp1 = tmp1 > tmp2 ? tmp1 : tmp2;
          }
          dout_batch[c] = tmp1;
        }
      }
    } else if (pooling_type == "avg") {
      // Pooling_average_include_padding
      // Pooling_average_exclude_padding
      for (int n = 0; n < num; ++n) {
        float* dout_batch = dout + n * chout * size_channel_out;
        const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
        for (int c = 0; c < chout; ++c) {
          const float* din_ch = din_batch + c * size_channel_in;  // in address
          float sum = 0.f;
          for (int i = 0; i < size_channel_in; ++i) {
            sum += din_ch[i];
          }
          dout_batch[c] = sum / size_channel_in;
        }
      }
    } else {
      LOG(FATAL) << "unsupported pooling type: " << pooling_type;
    }
  } else {
    if (pooling_type == "max") {
      // Pooling_max
      for (int n = 0; n < num; ++n) {
        float* dout_ch = dout + n * chout * size_channel_out;
        const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
        for (int c = 0; c < chout; c++) {
          float* dout_row = dout_ch + c * size_channel_out;
          const float* din_ch = din_batch + c * size_channel_in;
          for (int i = 0; i < hout; i++) {
            for (int j = 0; j < wout; j++) {
              int hstart = i * stride_h - pad_h;
              int wstart = j * stride_w - pad_w;
              int hend = std::min(hstart + kernel_h, hin + pad_h);
              int wend = std::min(wstart + kernel_w, win + pad_w);
              hstart = std::max(hstart, 0);
              wstart = std::max(wstart, 0);
              hend = std::min(hend, hin);
              wend = std::min(wend, win);
              int pool_size = (hend - hstart) * (wend - wstart);
              if (pool_size == 0) continue;
              float tmp1 = din_ch[hstart * win + wstart];
              for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                  float tmp2 = din_ch[h * win + w];
                  tmp1 = tmp1 > tmp2 ? tmp1 : tmp2;
                }
              }
              dout_row[j] = tmp1;
            }
            dout_row += wout;
          }
        }
      }
    } else if (pooling_type == "avg") {
      if (exclusive) {
        // Pooling_average_exclude_padding
        for (int n = 0; n < num; ++n) {
          float* dout_ch = dout + n * chout * size_channel_out;
          const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
          for (int c = 0; c < chout; c++) {
            float* dout_row = dout_ch + c * size_channel_out;
            const float* din_ch = din_batch + c * size_channel_in;
            for (int i = 0; i < hout; i++) {
              for (int j = 0; j < wout; j++) {
                int hstart = i * stride_h - pad_h;
                int wstart = j * stride_w - pad_w;
                int hend = std::min(hstart + kernel_h, hin + pad_h);
                int wend = std::min(wstart + kernel_w, win + pad_w);
                hstart = std::max(hstart, 0);
                wstart = std::max(wstart, 0);
                hend = std::min(hend, hin);
                wend = std::min(wend, win);
                int pool_size = (hend - hstart) * (wend - wstart);
                if (pool_size == 0) continue;
                float sum = 0.f;
                for (int h = hstart; h < hend; ++h) {
                  for (int w = wstart; w < wend; ++w) {
                    sum += din_ch[h * win + w];
                  }
                }
                dout_row[j] = sum / pool_size;
              }
              dout_row += wout;
            }
          }
        }
      } else {
        // Pooling_average_include_padding
        for (int n = 0; n < num; ++n) {
          float* dout_ch = dout + n * chout * size_channel_out;
          const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
          for (int c = 0; c < chout; c++) {
            float* dout_row = dout_ch + c * size_channel_out;
            const float* din_ch = din_batch + c * size_channel_in;
            for (int i = 0; i < hout; i++) {
              for (int j = 0; j < wout; j++) {
                int hstart = i * stride_h - pad_h;
                int wstart = j * stride_w - pad_w;
                int hend = std::min(hstart + kernel_h, hin + pad_h);
                int wend = std::min(wstart + kernel_w, win + pad_w);
                hstart = std::max(hstart, 0);
                wstart = std::max(wstart, 0);
                hend = std::min(hend, hin);
                wend = std::min(wend, win);
                int pool_size = (hend - hstart) * (wend - wstart);
                if (pool_size == 0) continue;
                float sum = 0.f;
                for (int h = hstart; h < hend; ++h) {
                  for (int w = wstart; w < wend; ++w) {
                    sum += din_ch[h * win + w];
                  }
                }
                dout_row[j] = sum / (kernel_w * kernel_h);
              }
              dout_row += wout;
            }
          }
        }
      }
    } else {
      LOG(FATAL) << "unsupported pooling type: " << pooling_type;
    }
  }
}
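The window-clipping arithmetic above is easiest to see with concrete numbers: hend is first limited by the padded extent (hin + pad_h), then both ends are clamped to the real input, so border windows shrink. Exclusive average divides by the clipped pool_size, include-padding divides by kernel_w * kernel_h regardless. A self-contained worked example of the clipping, using assumed values hin = 5, kernel_h = 3, stride_h = 2, pad_h = 1:

#include <algorithm>
#include <cstdio>

int main() {
  int hin = 5, kernel_h = 3, stride_h = 2, pad_h = 1;
  for (int i = 0; i < 3; ++i) {
    int hstart = i * stride_h - pad_h;
    int hend = std::min(hstart + kernel_h, hin + pad_h);
    hstart = std::max(hstart, 0);
    hend = std::min(hend, hin);
    // i=0 -> [0,2), i=1 -> [1,4), i=2 -> [3,5): the border windows shrink.
    std::printf("i=%d window=[%d,%d) rows=%d\n", i, hstart, hend,
                hend - hstart);
  }
  return 0;
}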
void pooling_global_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int size_channel_in = win * hin;
  int cnt = size_channel_in / 8;
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; ++c) {
      const float* din_ch = din_batch + c * size_channel_in;
      int i = 0;
      float minval = std::numeric_limits<float>::lowest();
      float32x4_t vmax = vdupq_n_f32(minval);
#ifdef __aarch64__
      for (; i < cnt; i++) {
        float32x4_t vdin1 = vld1q_f32(din_ch);
        vmax = vmaxq_f32(vdin1, vmax);
        float32x4_t vdin2 = vld1q_f32(din_ch + 4);
        vmax = vmaxq_f32(vmax, vdin2);
        din_ch += 8;
      }
#else
      int cnt_num = cnt;
      if (cnt_num > 0) {
        asm volatile(
            "max_loop:                          @main loop\n"
            "vld1.f32   {d0-d1}, [%[din_ch]]!   @load q1,din_ch\n"
            "vmax.f32   %q[vmax], %q[vmax], q0  @max vmax,vmax,din_ch\n"
            "vld1.f32   {d2-d3}, [%[din_ch]]!   @load 2nd 4 data\n"
            "vmax.f32   %q[vmax], %q[vmax], q1  @compare 2nd 4 datas\n"
            "subs       %[cnt_num], #1          @cnt_num--\n"
            "bne        max_loop                @bne cnt_num\n"
            : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num),
              [vmax] "+w"(vmax)
            :
            : "cc", "memory", "q0", "q1");
      }
#endif  // __aarch64__
      float32x2_t vmax_tmp = vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
      float tmp1 = vget_lane_f32(vmax_tmp, 0);
      float tmp2 = vget_lane_f32(vmax_tmp, 1);
      float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2;
      for (i = cnt * 8; i < size_channel_in; ++i) {
        /* code */
        max_tmp = max_tmp > din_ch[0] ? max_tmp : din_ch[0];
        din_ch++;
      }
      dout_batch[c] = max_tmp;
    }
  }
}
void pooling_global_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int size_channel_in = win * hin;
  int cnt = size_channel_in / 4;
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      const float* din_ch = din_batch + c * size_channel_in;  // in address
      int i = 0;
      float32x4_t vsum = vdupq_n_f32(0.0f);
#ifdef __aarch64__
      for (; i < cnt; i++) {
        vsum = vaddq_f32(vld1q_f32(din_ch), vsum);
        din_ch += 4;
      }
#else
      int cnt_num = cnt;
      if (cnt_num > 0) {
        asm volatile(
            "add_loop:                          @main loop\n"
            "vld1.f32   {d0-d1}, [%[din_ch]]!   @load q1,din_ch\n"
            "vadd.f32   %q[vsum], %q[vsum], q0  @add vmax,vmax, din_ch\n"
            "subs       %[cnt_num], #1          @cnt_num--\n"
            "bne        add_loop                @bne num\n"
            : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num),
              [vsum] "+w"(vsum)
            :
            : "cc", "memory", "q0");
      }
#endif  // __aarch64__
      float32x2_t vsum_tmp = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum));
      float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1);
      for (i = cnt * 4; i < size_channel_in; i++) {
        sum += din_ch[0];
        din_ch++;
      }
      dout_batch[c] = sum / size_channel_in;
    }
  }
}
void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win) {
  int kernel = 2;
  int stride = 2;
  int padding = 0;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1);
  int h_needed = (hout << 1);
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = (w_limit >> 1) << 1;
  int h_even = (h_limit >> 1) << 1;
  int w_unroll_size = (w_even >> 3) << 3;
  // int w_unroll_remain = w_even - w_unroll_size;
  int w_in_2 = win << 1;
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      int h = 0;
      for (; h < h_even; h += 2) {
        int w = 0;
#ifdef __aarch64__
        for (; w < w_unroll_size; w += 8) {
          float32x4_t dr00 = vld1q_f32(&r0[w]);
          float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
          float32x4_t dr10 = vld1q_f32(&r1[w]);
          float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
          float32x4_t dmax1 = vmaxq_f32(dr00, dr10);
          float32x4_t dmax2 = vmaxq_f32(dr01, dr11);
#ifdef __aarch64__
          float32x4_t dmax = vpmaxq_f32(dmax1, dmax2);
#else
          float32x2_t dmaxl =
              vpmax_f32(vget_low_f32(dmax1), vget_high_f32(dmax1));
          float32x2_t dmaxh =
              vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2));
          float32x4_t dmax = vcombine_f32(dmaxl, dmaxh);
#endif
          vst1q_f32(&dout_ch[w >> 1], dmax);
        }
#else
        float* dr_out = dout_ch;
        const float* dr0 = r0;
        const float* dr1 = r1;
        int cnt_num = w_unroll_size >> 3;
        if (cnt_num > 0) {
          asm volatile(
              "s2_max_loop:                       @main loop\n"
              "vld1.f32   {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
              "vld1.f32   {d4-d7}, [%[dr1]]!      @load q1,dr1\n"
              "vmax.f32   q0, q0, q2              @max q0,q0,q2\n"
              "vmax.f32   q1, q1, q3              @max q1,q1,q2\n"
              "vpmax.f32  d4, d0, d1              @max d4,d0,d1\n"
              "vpmax.f32  d5, d2, d3              @max d5,d2,d3\n"
              "vst1.f32   {d4-d5}, [%[dr_out]]!   @vst1 q2,dr_out\n"
              "subs       %[cnt_num], #1          @cnt_num--\n"
              "bne        s2_max_loop             @bne cnt_num\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num)
              :
              : "cc", "memory", "q0", "q1", "q2", "q3");
        }
        w = w_unroll_size;
#endif  // __aarch64__
        for (; w < w_even; w += 2) {
          dout_ch[w >> 1] =
              std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1]));
        }
        for (; w < w_limit; ++w) {
          // run 0 or 1 time
          dout_ch[w >> 1] = std::max(r0[w], r1[w]);
        }
        r0 += w_in_2;  // << 1;
        r1 += w_in_2;  // << 1;
        dout_ch += wout;
      }
      // process remain row (odd, last row)
      for (; h < h_limit; h++) {
        // run 0 or 1 time
        int w = 0;
#ifdef __aarch64__
        for (; w < w_unroll_size; w += 8) {
          float32x4_t dr00 = vld1q_f32(&r0[w]);
          float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
#ifdef __aarch64__
          float32x4_t dmax = vpmaxq_f32(dr00, dr01);
#else
          float32x2_t dmaxl =
              vpmax_f32(vget_low_f32(dr00), vget_high_f32(dr00));
          float32x2_t dmaxh =
              vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01));
          float32x4_t dmax = vcombine_f32(dmaxl, dmaxh);
#endif
          vst1q_f32(&dout_ch[w >> 1], dmax);
        }
#else
        float* dr_out = dout_ch;
        const float* dr0 = r0;
        int cnt_num = w_unroll_size >> 3;
        if (cnt_num > 0) {
          asm volatile(
              "s2_max_loop1:                      @main loop\n"
              "vld1.f32   {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
              "vpmax.f32  d4, d0, d1              @max d4,d0,d1\n"
              "vpmax.f32  d5, d2, d3              @max d5,d2,d3\n"
              "vst1.f32   {d4-d5}, [%[dr_out]]!   @vst1 q2,dr_out\n"
              "subs       %[cnt_num], #1          @cnt_num--\n"
              "bne        s2_max_loop1            @bne cnt_num\n"
              : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num)
              :
              : "cc", "memory", "q0", "q1", "q2");
        }
        w = w_unroll_size;
#endif  // __aarch64__
        for (; w < w_even; w += 2) {
          dout_ch[w >> 1] = std::max(r0[w], r0[w + 1]);
        }
        for (; w < w_limit; ++w) {
          // run 0 or 1 time
          dout_ch[w >> 1] = r0[w];
        }
      }
    }
  }
}
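The w_needed / w_limit / w_even bookkeeping decides how much input each output column may actually read when the input width is odd (ceil-mode output). A self-contained worked example with assumed values win = 7, wout = 4: the first three output columns each cover a 2-wide pair, and the fourth sees only a single input column.

#include <cstdio>

int main() {
  int win = 7, wout = 4;
  int w_needed = wout << 1;                       // 8 input cols wanted
  int w_limit = w_needed > win ? win : w_needed;  // 7 actually available
  int w_even = (w_limit >> 1) << 1;               // 6 covered by 2-wide pairs
  std::printf("pairs over [0,%d), single-column tail cols: %d\n", w_even,
              w_limit - w_even);                  // pairs over [0,6), tail 1
  return 0;
}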
void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win,
                      bool exclusive) {
  int kernel = 2;
  int stride = 2;
  int padding = 0;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1);
  int h_needed = (hout << 1);
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = (w_limit >> 1) << 1;
  int h_even = (h_limit >> 1) << 1;
  int w_unroll_size = (w_even >> 3) << 3;
  // int w_unroll_remain = w_even - w_unroll_size;
  int w_in_2 = win << 1;
  const float coef = 1.f / 4.f;
  const float coef_1 = exclusive ? 1.f : coef;
  const float coef_2 = exclusive ? 1.f / 2.f : coef;
  float32x4_t vcoef = vdupq_n_f32(coef);
  float32x4_t vcoef_1 = vdupq_n_f32(coef_1);
  float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      int h = 0;
      for (; h < h_even; h += 2) {
        int w = 0;
#ifdef __aarch64__
        for (; w < w_unroll_size; w += 8) {
          float32x4_t dr00 = vld1q_f32(&r0[w]);
          float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
          float32x4_t dr10 = vld1q_f32(&r1[w]);
          float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
          float32x4_t dsum1 = vaddq_f32(dr00, dr10);
          float32x4_t dsum2 = vaddq_f32(dr01, dr11);
#ifdef __aarch64__
          float32x4_t dsum = vpaddq_f32(dsum1, dsum2);
#else
          float32x2_t dsuml =
              vpadd_f32(vget_low_f32(dsum1), vget_high_f32(dsum1));
          float32x2_t dsumh =
              vpadd_f32(vget_low_f32(dsum2), vget_high_f32(dsum2));
          float32x4_t dsum = vcombine_f32(dsuml, dsumh);
#endif
          float32x4_t res = vmulq_f32(dsum, vcoef);
          vst1q_f32(&dout_ch[w >> 1], res);
        }
#else
        float* dr_out = dout_ch;
        const float* dr0 = r0;
        const float* dr1 = r1;
        int cnt_num = w_unroll_size >> 3;
        if (cnt_num > 0) {
          asm volatile(
              "1:                                 @main loop\n"
              "vld1.f32   {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
              "vld1.f32   {d4-d7}, [%[dr1]]!      @load q1,dr1\n"
              "vadd.f32   q0, q0, q2              @add q0,q0,q2\n"
              "vadd.f32   q1, q1, q3              @add q1,q1,q2\n"
              "vpadd.f32  d4, d0, d1              @add d4,d0,d1\n"
              "vpadd.f32  d5, d2, d3              @add d5,d2,d3\n"
              "vmul.f32   q2, q2, %q[vcoef]       @mul q2,q2,vcoef\n"
              "vst1.f32   {d4-d5}, [%[dr_out]]!   @vst1 q2,dr_out\n"
              "subs       %[cnt_num], #1          @cnt_num--\n"
              "bne        1b                      @bne cnt_num\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [vcoef] "+w"(vcoef), [cnt_num] "+r"(cnt_num)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "w"(vcoef)
              : "cc", "memory", "q0", "q1", "q2", "q3");
        }
        w = w_unroll_size;
#endif  // __aarch64__
        for (; w < w_even; w += 2) {
          dout_ch[w >> 1] = (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) * coef;
        }
        for (; w < w_limit; ++w) {
          // run 0 or 1 time
          dout_ch[w >> 1] = (r0[w] + r1[w]) * coef_2;
        }
        r0 += w_in_2;  // << 1;
        r1 += w_in_2;  // << 1;
        dout_ch += wout;
      }
      // process remain row (odd, last row)
      for (; h < h_limit; h++) {
        // run 0 or 1 time
        int w = 0;
#ifdef __aarch64__
        for (; w < w_unroll_size; w += 8) {
          float32x4_t dr00 = vld1q_f32(&r0[w]);
          float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
#ifdef __aarch64__
          float32x4_t dsum = vpaddq_f32(dr00, dr01);
#else
          float32x2_t dsuml =
              vpadd_f32(vget_low_f32(dr00), vget_high_f32(dr00));
          float32x2_t dsumh =
              vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01));
          float32x4_t dsum = vcombine_f32(dsuml, dsumh);
#endif
          float32x4_t res = vmulq_f32(dsum, vcoef_2);
          vst1q_f32(&dout_ch[w >> 1], res);
        }
#else
        float* dr_out = dout_ch;
        const float* dr0 = r0;
        int cnt_num = w_unroll_size >> 3;
        if (cnt_num > 0) {
          asm volatile(
              "1:                                 @main loop\n"
              "vld1.f32   {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
              "vpadd.f32  d4, d0, d1              @add d4,d0,d1\n"
              "vpadd.f32  d5, d2, d3              @add d5,d2,d3\n"
              "vmul.f32   q2, q2, %q[vcoef_2]     @mul q2,q2,vcoef_2\n"
              "vst1.f32   {d4-d5}, [%[dr_out]]!   @vst1 q2,dr_out\n"
              "subs       %[cnt_num], #1          @cnt_num--\n"
              "bne        1b                      @bne cnt_num\n"
              : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out),
                [vcoef_2] "+w"(vcoef_2), [cnt_num] "+r"(cnt_num)
              : "r"(dr0), "r"(dr_out), "r"(cnt_num), "w"(vcoef_2)
              : "cc", "memory", "q0", "q1", "q2");
        }
        w = w_unroll_size;
#endif  // __aarch64__
        for (; w < w_even; w += 2) {
          dout_ch[w >> 1] = (r0[w] + r0[w + 1]) * coef_2;
        }
        for (; w < w_limit; ++w) {
          // run 0 or 1 time
          dout_ch[w >> 1] = r0[w] * coef_1;
        }
      }
    }
  }
}
void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int kernel = 3;
  int stride = 1;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_unroll_size = ((win - 2) >> 2) << 2;
  int w_unroll_remain = win - 2 - w_unroll_size;
  const float minval = std::numeric_limits<float>::lowest();
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 2;  // w_unroll_size / 4
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 0;
      int cnt = 1;
      // left
      dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#ifdef __aarch64__
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
        float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
        float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
        float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
        float32x2_t vmax_12_34 =
            vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
        float32x2_t vmax_23_45 =
            vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
        float32x2_t vmax_34_56 =
            vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
        float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
        float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
        float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
        vst1q_f32(&dout_ch[cnt], vmax);
        cnt += 4;
      }
#else
      dr_out = dr_out + 1;
      if (cnt_num > 0) {
        asm volatile(
            "1:                              @main loop\n"
            "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d4-d5}, [%[dr1]]!    @load d4-d7,dr1\n"
            "vld1.f32  {d2}, [%[dr0]]!       @load d0-d5,dr0\n"
            "vld1.f32  {d6}, [%[dr1]]!       @load d4-d7,dr1\n"
            "vmax.f32  q5, q0, q2            @max r0_1234,r1_1234\n"
            "vmax.f32  d12, d2, d6           @max r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6                @mov s7,s6\n"
            "vext.f32  q0, q5, q6, #1        @vext max_2345\n"
            "vext.f32  q2, q5, q6, #2        @vext max_3456\n"
            "vpmax.f32 d2, d10, d11          @pmax d4,max_1234,max_1234\n"
            "vpmax.f32 d3, d0, d1            @pmax d4,max_2345,max_2345\n"
            "vpmax.f32 d6, d4, d5            @pmax d6,max_3456,max_3456\n"
            "vmax.f32  d8, d2, d3            @max d2,vmax_12_34,vmax_23_45\n"
            "vmax.f32  d9, d3, d6            @max d2,vmax_23_45,vmax_34_56\n"
            "sub       %[dr0], #8            @sub w,8\n"
            "sub       %[dr1], #8            @sub w,8\n"
            // swap
            "vmov.f32  s0, s17               @mov\n"
            "vmov.f32  s17, s18              @mov\n"
            "vmov.f32  s18, s0               @mov\n"
            "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
            "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "bne       1b                    @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_max = std::max(r0[j + w], r1[j + w]);
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
        dout_ch[j + w + 1] = tmp_max;
      }
      // right
      float tmp = std::max(r0[win - 2], r1[win - 2]);
      tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
      dout_ch[wout - 1] = tmp;
      // r0 = r1;
      // r1 = r0 + w_in;
      // r2 = r1 + w_in;
      dout_ch += wout;
      int h = 0;
      for (; h < hin - 2; h += 1) {
        // deal with left pad
        float maxr0 = std::max(r0[0], r0[1]);
        float maxr1 = std::max(r1[0], r1[1]);
        float maxr2 = std::max(r2[0], r2[1]);
        dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
        w = 0;
        cnt = 1;
        for (; w < w_unroll_size; w += 4) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_34_56 =
              vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
          float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
          vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
          vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
          vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
          vst1q_f32(&dout_ch[cnt], vmax);
          cnt += 4;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = r0;
        dr1 = r1;
        dr2 = r2;
        cnt_num = w_unroll_size >> 2;
        if (cnt_num > 0) {
          asm volatile(
              "1:                              @main loop\n"
              "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d5,dr0\n"
              "vld1.f32  {d4-d5}, [%[dr1]]!    @load d4-d7,dr1\n"
              "vld1.f32  {d8-d9}, [%[dr2]]!    @load d4-d7,dr1\n"
              "vld1.f32  {d2}, [%[dr0]]!       @load d0-d5,dr0\n"
              "vld1.f32  {d6}, [%[dr1]]!       @load d4-d7,dr1\n"
              "vld1.f32  {d10}, [%[dr2]]!      @load d4-d7, dr1\n"
              "vmax.f32  q7, q0, q2            @max r0_1234,r1_1234\n"
              "vmax.f32  d16, d2, d6           @max r0_5678,r1_5678\n"
              "vmax.f32  q3, q7, q4            @max r0_1234,r1_1234\n"
              "vmax.f32  d12, d16, d10         @max r0_5678,r1_5678\n"
              //"vmov.f32 s7,s6                @mov s7,s6\n"
              "vext.f32  q0, q3, q6, #1        @vext max_2345\n"
              "vext.f32  q2, q3, q6, #2        @vext max_3456\n"
              "vpmax.f32 d2, d6, d7            @pmax d4,max_1234,max_1234\n"
              "vpmax.f32 d3, d0, d1            @pmax d4,max_2345,max_2345\n"
              "vpmax.f32 d6, d4, d5            @pmax d6,max_3456,max_3456\n"
              "vmax.f32  d8, d2, d3            @max d2,vmax_12_34,vmax_23_45\n"
              "vmax.f32  d9, d3, d6            @max d2,vmax_23_45,vmax_34_56\n"
              "sub       %[dr0], #8            @sub w,8\n"
              "sub       %[dr1], #8            @sub w,8\n"
              "sub       %[dr2], #8            @sub w,8\n"
              // swap
              "vmov.f32  s0, s17               @mov\n"
              "vmov.f32  s17, s18              @mov\n"
              "vmov.f32  s18, s0               @mov\n"
              "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
              "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "bne       1b                    @bne s1_max_loop\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8");
        }
#endif
        // remain
        w = w_unroll_size;
        for (int j = 0; j < w_unroll_remain; j++) {
          float tmp_max = std::max(r0[j + w], r1[j + w]);
          tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
          tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
          tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1]));
          tmp_max = std::max(tmp_max, r2[j + w + 2]);
          dout_ch[j + w + 1] = tmp_max;
        }
        // right
        tmp = std::max(r0[win - 2], r1[win - 2]);
        tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
        tmp = std::max(tmp, std::max(r2[win - 2], r2[win - 1]));
        dout_ch[wout - 1] = tmp;
        r0 = r1;
        r1 = r2;
        r2 = r1 + win;
        dout_ch += wout;
      }
      // the last two lines
      float maxr0 = std::max(r0[0], r0[1]);
      float maxr1 = std::max(r1[0], r1[1]);
      dout_ch[0] = std::max(maxr0, maxr1);
#ifdef __aarch64__
      w = 0;
      cnt = 1;
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
        float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
        float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
        float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
        float32x2_t vmax_12_34 =
            vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
        float32x2_t vmax_23_45 =
            vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
        float32x2_t vmax_34_56 =
            vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
        float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
        float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
        float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
        vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
        vst1q_f32(&dout_ch[cnt], vmax);
        cnt += 4;
      }
#else
      dr_out = dout_ch + 1;
      dr0 = r0;
      dr1 = r1;
      cnt_num = w_unroll_size >> 2;
      if (cnt_num > 0) {
        asm volatile(
            "1:                              @main loop\n"
            "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d4-d5}, [%[dr1]]!    @load d4-d7,dr1\n"
            "vld1.f32  {d2}, [%[dr0]]!       @load d0-d5,dr0\n"
            "vld1.f32  {d6}, [%[dr1]]!       @load d4-d7,dr1\n"
            "vmax.f32  q5, q0, q2            @max r0_1234,r1_1234\n"
            "vmax.f32  d12, d2, d6           @max r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6                @mov s7,s6\n"
            "vext.f32  q0, q5, q6, #1        @vext max_2345\n"
            "vext.f32  q2, q5, q6, #2        @vext max_3456\n"
            "vpmax.f32 d2, d10, d11          @pmax d4,max_1234,max_1234\n"
            "vpmax.f32 d3, d0, d1            @pmax d4,max_2345,max_2345\n"
            "vpmax.f32 d6, d4, d5            @pmax d6,max_3456,max_3456\n"
            "vmax.f32  d8, d2, d3            @max d2,vmax_12_34,vmax_23_45\n"
            "vmax.f32  d9, d3, d6            @max d2,vmax_23_45,vmax_34_56\n"
            "sub       %[dr0], #8            @sub w,8\n"
            "sub       %[dr1], #8            @sub w,8\n"
            // swap
            "vmov.f32  s0, s17               @mov\n"
            "vmov.f32  s17, s18              @mov\n"
            "vmov.f32  s18, s0               @mov\n"
            "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
            "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "bne       1b                    @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_max = std::max(r0[j + w], r1[j + w]);
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
        tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
        dout_ch[j + w + 1] = tmp_max;
      }
      tmp = std::max(r0[win - 2], r1[win - 2]);
      tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
      dout_ch[wout - 1] = tmp;
    }
  }
}
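
// pooling3x3s1p1_avg: same traversal as the max kernel above, with vadd in
// place of vmax. When `exclusive` is set, windows clipped by the zero padding
// divide by the number of taps they actually cover (4 at the corners via
// coef_4, 6 along the edges via coef_6); otherwise every window divides by 9
// (coef).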
void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive) {
  int kernel = 3;
  int stride = 1;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_unroll_size = ((win - 2) >> 2) << 2;
  int w_unroll_remain = win - 2 - w_unroll_size;
  const float coef = 1.f / 9.f;
  const float coef_2 = exclusive ? 1.f / 2.f : coef;
  const float coef_4 = exclusive ? 1.f / 4.f : coef;
  const float coef_6 = exclusive ? 1.f / 6.f : coef;
  float32x4_t vcoef = vdupq_n_f32(coef);
  float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
  float32x4_t vcoef_4 = vdupq_n_f32(coef_4);
  float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 2;  // w_unroll_size / 4
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 0;
      int cnt = 1;
      // left
      dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
// first row with zero pad
#ifdef __aarch64__
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
        float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
        float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
        float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
        float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
        vsum = vaddq_f32(vsum, vsum_3456);
        vsum = vmulq_f32(vsum, vcoef_6);
        vst1q_f32(&dout_ch[cnt], vsum);
        cnt += 4;
      }
#else
      dr_out = dr_out + 1;
      if (cnt_num > 0) {
        asm volatile(
            "1:                              @main loop\n"
            "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d4-d5}, [%[dr1]]!    @load d4-d7,dr1\n"
            "vld1.f32  {d2}, [%[dr0]]!       @load d0-d5,dr0\n"
            "vld1.f32  {d6}, [%[dr1]]!       @load d4-d7,dr1\n"
            "vadd.f32  q5, q0, q2            @max r0_1234,r1_1234\n"
            "vadd.f32  d12, d2, d6           @max r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6                @mov s7,s6\n"
            "vext.f32  q0, q5, q6, #1        @vext max_2345\n"
            "vext.f32  q2, q5, q6, #2        @vext max_3456\n"
            "vadd.f32  q1, q5, q0            @add 1234+2345\n"
            "vadd.f32  q1, q1, q2            @add + 3456\n"
            "vmul.f32  q4, q1, %q[vcoef_6]   @mul * 1/9.f\n"
            "sub       %[dr0], #8            @sub w,8\n"
            "sub       %[dr1], #8            @sub w,8\n"
            "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
            "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "bne       1b                    @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_sum = r0[j + w] + r1[j + w];
        tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
        tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
        dout_ch[j + w + 1] = tmp_sum * coef_6;
      }
      // right
      float tmp = r0[win - 2] + r1[win - 2];
      tmp += (r0[win - 1] + r1[win - 1]);
      dout_ch[wout - 1] = tmp * coef_4;
      // r0 = r1;
      // r1 = r0 + w_in;
      // r2 = r1 + w_in;
      dout_ch += wout;
      int h = 0;
      for (; h < hin - 2; h += 1) {
        // deal with left pad
        float maxr0 = r0[0] + r0[1];
        float maxr1 = r1[0] + r1[1];
        float maxr2 = r2[0] + r2[1];
        dout_ch[0] = (maxr0 + maxr1 + maxr2) * coef_6;
#ifdef __aarch64__
        w = 0;
        cnt = 1;
        for (; w < w_unroll_size; w += 4) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
          vsum = vaddq_f32(vsum, vsum_3456);
          vsum = vmulq_f32(vsum, vcoef);
          vst1q_f32(&dout_ch[cnt], vsum);
          cnt += 4;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = r0;
        dr1 = r1;
        dr2 = r2;
        cnt_num = w_unroll_size >> 2;
        if (cnt_num > 0) {
          asm volatile(
              "1:                              @main loop\n"
              "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d5,dr0\n"
              "vld1.f32  {d4-d5}, [%[dr1]]!    @load d4-d7,dr1\n"
              "vld1.f32  {d8-d9}, [%[dr2]]!    @load d4-d7,dr1\n"
              "vld1.f32  {d2}, [%[dr0]]!       @load d0-d5,dr0\n"
              "vld1.f32  {d6}, [%[dr1]]!       @load d4-d7,dr1\n"
              "vld1.f32  {d10}, [%[dr2]]!      @load d4-d7,dr1\n"
              "vadd.f32  q7, q0, q2            @max r0_1234,r1_1234\n"
              "vadd.f32  d16, d2, d6           @max r0_5678,r1_5678\n"
              "vadd.f32  q3, q7, q4            @max r0_1234,r1_1234\n"
              "vadd.f32  d12, d16, d10         @max r0_5678,r1_5678\n"
              //"vmov.f32 s7,s6                @mov s7,s6\n"
              "vext.f32  q0, q3, q6, #1        @vext max_2345\n"
              "vext.f32  q2, q3, q6, #2        @vext max_3456\n"
              "vadd.f32  q1, q3, q0            @add 1234+2345\n"
              "vadd.f32  q1, q1, q2            @add+3456\n"
              "vmul.f32  q4, q1, %q[vcoef]     @mul*1/9.f\n"
              "sub       %[dr0], #8            @sub w,8\n"
              "sub       %[dr1], #8            @sub w,8\n"
              "sub       %[dr2], #8            @sub w,8\n"
              "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
              "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "bne       1b                    @bne s1_max_loop\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [vcoef] "+w"(vcoef)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8");
        }
#endif
        // remain
        w = w_unroll_size;
        for (int j = 0; j < w_unroll_remain; j++) {
          float tmp_sum = r0[j + w] + r1[j + w];
          tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
          tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
          tmp_sum += (r2[j + w + 1] + r2[j + w + 2]);
          tmp_sum += r2[j + w];
          dout_ch[j + w + 1] = tmp_sum * coef;
        }
        // right
        tmp = r0[win - 2] + r1[win - 2];
        tmp += (r0[win - 1] + r1[win - 1]);
        tmp += (r2[win - 2] + r2[win - 1]);
        dout_ch[wout - 1] = tmp * coef_6;
        r0 = r1;
        r1 = r2;
        r2 = r1 + win;
        dout_ch += wout;
      }
      // last line
      float maxr0 = (r0[0] + r0[1]);
      float maxr1 = (r1[0] + r1[1]);
      dout_ch[0] = (maxr0 + maxr1) * coef_4;
#ifdef __aarch64__
      w = 0;
      cnt = 1;
      for (; w < w_unroll_size; w += 4) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
        float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
        float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
        float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
        float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
        vsum = vaddq_f32(vsum, vsum_3456);
        vsum = vmulq_f32(vsum, vcoef_6);
        vst1q_f32(&dout_ch[cnt], vsum);
        cnt += 4;
      }
#else
      dr_out = dout_ch + 1;
      dr0 = r0;
      dr1 = r1;
      cnt_num = w_unroll_size >> 2;
      if (cnt_num > 0) {
        asm volatile(
            "1:                              @main loop\n"
            "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d4-d5}, [%[dr1]]!    @load d4-d7,dr1\n"
            "vld1.f32  {d2}, [%[dr0]]!       @load d0-d5,dr0\n"
            "vld1.f32  {d6}, [%[dr1]]!       @load d4-d7,dr1\n"
            "vadd.f32  q5, q0, q2            @max r0_1234,r1_1234\n"
            "vadd.f32  d12, d2, d6           @max r0_5678,r1_5678\n"
            //"vmov.f32 s7,s6                @mov s7,s6\n"
            "vext.f32  q0, q5, q6, #1        @vext max_2345\n"
            "vext.f32  q2, q5, q6, #2        @vext max_3456\n"
            "vadd.f32  q1, q5, q0            @add 1234+2345\n"
            "vadd.f32  q1, q1, q2            @add + 3456\n"
            "vmul.f32  q4, q1, %q[vcoef_6]   @mul * 1/9.f\n"
            "sub       %[dr0], #8            @sub w,8\n"
            "sub       %[dr1], #8            @sub w,8\n"
            "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
            "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "bne       1b                    @bne s1_max_loop\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
      }
#endif
      // remain
      w = w_unroll_size;
      for (int j = 0; j < w_unroll_remain; j++) {
        float tmp_sum = r0[j + w] + r1[j + w];
        tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
        tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
        dout_ch[j + w + 1] = tmp_sum * coef_6;
      }
      // right
      tmp = r0[win - 2] + r1[win - 2];
      tmp += (r0[win - 1] + r1[win - 1]);
      dout_ch[wout - 1] = tmp * coef_4;
    }
  }
}
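
// pooling3x3s2p1_max: 3x3 max pooling with stride 2 and padding 1. An output
// row of width wout consumes w_needed = 2 * wout + 1 input columns, so the
// w_limit / w_even / w_remain bookkeeping below clamps that span to win and
// splits it into an 8-column unrolled body, a 2-column remainder, and an
// optional right-pad window; h_needed / h_remain do the same for the rows.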
void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int kernel = 3;
  int stride = 2;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = (w_limit >> 1) << 1;
  int h_even = (h_limit >> 1) << 1;
  int w_unroll_size = ((w_even - 1) >> 3) << 3;
  int w_unroll_remain = w_even - 1 - w_unroll_size;
  int w_remain = w_needed - w_limit - padding;
  int h_remain = h_needed - h_limit - padding;
  int w_in_2 = win << 1;
  float minval = std::numeric_limits<float>::lowest();
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 3;
      int cnt_num_remain = w_unroll_remain >> 1;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 1;
      int cnt = 1;
      dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#if __aarch64__
      for (; w < w_unroll_size; w += 8) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
        float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
        float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
        float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
        float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
        float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
        float32x2_t vmax_12_34 =
            vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
        float32x2_t vmax_23_45 =
            vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
        float32x2_t vmax_56_78 =
            vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
        float32x2_t vmax_67_89 =
            vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
        float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
        float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
        vst1_f32(&dout_ch[cnt], vmax_123_345);
        vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
        cnt += 4;
      }
      for (; w < w_even - 1; w += 2) {
        float32x4_t vr0 = vld1q_f32(&r0[w]);
        float32x4_t vr1 = vld1q_f32(&r1[w]);
        vr0 = vsetq_lane_f32(minval, vr0, 3);
        vr1 = vsetq_lane_f32(minval, vr1, 3);
        float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
        float32x2_t vmax2 =
            vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
        vmax2 = vpmax_f32(vmax2, vmax2);
        dout_ch[cnt] = vget_lane_f32(vmax2, 0);
        cnt++;
      }
#else
      dr0 = dr0 + 1;
      dr1 = dr1 + 1;
      dr_out = dr_out + 1;
      // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
      // cnt_num_remain;
      if (cnt_num > 0 || cnt_num_remain > 0) {
        asm volatile(
            "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
            "ble       3f                    @ble exit\n"
            "1:                              @main loop\n"
            "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
            "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d10-d11}, [%[dr1]]!  @load d4-d7,dr1\n"
            "vmax.f32  q6, q0, q3            @max r0_1234,r1_1234\n"
            "vmax.f32  q7, q1, q4            @max r0_5678,r1_5678\n"
            "vmax.f32  q8, q2, q5            @max r0_9101112,r1_9101112\n"
            //"vmov.f32 s7,s6                @mov s7,s6\n"
            "vext.f32  q0, q6, q7, #1        @vext max_2345\n"
            "vext.f32  q1, q7, q8, #1        @vext max_6789\n"
            "vpmax.f32 d4, d12, d13          @pmax d4,vmax_1234,vmax_1234\n"
            "vpmax.f32 d6, d14, d15          @pmax d6,vmax_5678,vmax_5678\n"
            "vpmax.f32 d5, d0, d1            @pmax d5,vmax_2345,vmax_2345\n"
            "vpmax.f32 d7, d2, d3            @pmax d7,vmax_6789,vmax_6789\n"
            "vmax.f32  d8, d4, d5            @max d2,vmax_12_34,vmax_23_45\n"
            "vmax.f32  d9, d6, d7            @max d2,vmax_56_78,vmax_67_89\n"
            "sub       %[dr0], #16           @add w,8\n"
            "sub       %[dr1], #16           @add w, 8\n"
            "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "subs      %[cnt_num], #1        @subs cnt_num, #1\n"
            "bne       1b                    @bne s3_max_loop\n"
            "3:                              @loop\n"
            "cmp       %[cnt_num_remain], #0 @cmp cnt_num,0\n"
            "ble       4f                    @ble exit\n"
            "2:                              @main loop\n"
            "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
            "vld1.f32  {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
            "vmov.f32  s3,s2                 @movs3,s2\n"
            "vmov.f32  s7,s6                 @movs7,s6\n"
            "vmax.f32  q0, q0, q1            @max q0,q0,q1\n"
            "vpmax.f32 d0, d0, d1            @pmax d0,d0,d1\n"
            "vpmax.f32 d0, d0, d0            @pmax d0,d0,d0\n"
            "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
            "sub       %[dr0], #8            @add w,6\n"
            "sub       %[dr1], #8            @add w,6\n"
            "subs      %[cnt_num_remain], #1 @subs cnt_num,#1\n"
            "bne       2b                    @bne s3_max_loop_1\n"
            "4:                              @exit\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
              "r"(cnt_num_remain)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9");
      }
#endif
      // int w = w_even - 1;
      if (w_remain > 0) {
        // deal with right pad
        int wstart = (w_even >> 1) * stride - padding;
        int wend = std::min(std::min(wstart + kernel, win + padding), win);
        float tmp = r0[wstart];  // std::numeric_limits<float>::min();
        for (int i = wstart; i < wend; i++) {  // only run 1 or 2 times
          tmp = std::max(tmp, std::max(r0[i], r1[i]));
        }
        dout_ch[w_even >> 1] = tmp;
        // cnt ++;
      }
      r0 = r1;
      r1 = r0 + win;
      r2 = r1 + win;
      dout_ch += wout;
      int h = 2;
      for (; h < h_even; h += 2) {
        // deal with left pad
        float maxr0 = std::max(r0[0], r0[1]);
        float maxr1 = std::max(r1[0], r1[1]);
        float maxr2 = std::max(r2[0], r2[1]);
        dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#if __aarch64__
        w = 1;
        cnt = 1;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
          float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
          vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_56_78 =
              vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
          float32x2_t vmax_67_89 =
              vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
          vst1_f32(&dout_ch[cnt], vmax_123_345);
          vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
          cnt += 4;
        }
        for (; w < w_even - 1; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(minval, vr0, 3);
          vr1 = vsetq_lane_f32(minval, vr1, 3);
          vr2 = vsetq_lane_f32(minval, vr2, 3);
          float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
          vmax1 = vmaxq_f32(vmax1, vr2);
          float32x2_t vmax2 =
              vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
          float32x2_t vmax = vpmax_f32(vmax2, vmax2);
          dout_ch[cnt] = vget_lane_f32(vmax, 0);
          cnt++;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = (r0 + 1);
        dr1 = (r1 + 1);
        dr2 = (r2 + 1);
        cnt_num = w_unroll_size >> 3;
        cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
              "ble       3f                    @ble exit\n"
              "1:                              @main loop\n"
              "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d5,dr0\n"
              "vld1.f32  {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
              "vld1.f32  {d12-d15}, [%[dr2]]!  @load d4-d7,dr1\n"
              "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d5,dr0\n"
              "vld1.f32  {d10-d11}, [%[dr1]]!  @load d4-d7,dr1\n"
              "vld1.f32  {d16-d17}, [%[dr2]]!  @load d4-d7,dr1\n"
              "vmax.f32  q9, q0, q3            @max q0,q0,q2\n"
              "vmax.f32  q10, q1, q4           @max q1,q1,q3\n"
              "vmax.f32  q11, q2, q5           @max q1,q1,q3\n"
              "vmax.f32  q0, q9, q6            @max q0,q0,q2 1234\n"
              "vmax.f32  q3, q10, q7           @max q1,q1,q3 5678\n"
              "vmax.f32  q1, q11, q8           @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6                @mov s7, s6\n"
              "vext.f32  q4, q0, q3, #1        @vext 2345\n"
              "vext.f32  q2, q3, q1, #1        @vext 6789\n"
              "vpmax.f32 d10, d0, d1           @pmax d10,vmax_1234,vmax_1234\n"
              "vpmax.f32 d12, d6, d7           @pmax d12,vmax_5678,vmax_5678\n"
              "vpmax.f32 d11, d8, d9           @pmax d11,vmax_2345,vmax_2345\n"
              "vpmax.f32 d13, d4, d5           @pmax d13,vmax_6789,vmax_6789\n"
              "vmax.f32  d0, d10, d11          @pmax d0,vmax_12_34,vmax_23_45\n"
              "vmax.f32  d1, d12, d13          @pmax d1,vmax_56_78,vmax_67_89\n"
              "sub       %[dr0], #16           @add w,8\n"
              "sub       %[dr1], #16           @add w,8\n"
              "sub       %[dr2], #16           @add w,8\n"
              "vst1.f32  d0, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "vst1.f32  d1, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
              "bne       1b                    @bne s3_max_loop_mid\n"
              "3:                              @loop\n"
              "cmp       %[cnt_num_remain], #0 @cmp cnt_num,0\n"
              "ble       4f                    @ble exit1\n"
              "2:                              @mid loop\n"
              "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
              "vld1.f32  {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
              "vld1.f32  {d4-d5}, [%[dr2]]!    @load d2-d3,dr1\n"
              "vmov.f32  s3,s2                 @movs3,s2\n"
              "vmov.f32  s7,s6                 @movs7,s6\n"
              "vmov.f32  s11,s10               @movs11,s10\n"
              "vmax.f32  q0, q0, q1            @max q0,q0,q1\n"
              "vmax.f32  q0, q0, q2            @max q0,q0,q2\n"
              "vpmax.f32 d0, d0, d1            @pmax d0,d0,d1\n"
              "vpmax.f32 d0, d0, d0            @pmax d0, d0,d0\n"
              "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
              "sub       %[dr0], #8            @add w,6\n"
              "sub       %[dr1], #8            @add w,6\n"
              "sub       %[dr2], #8            @add w,6\n"
              "subs      %[cnt_num_remain], #1 @subs cnt_num,#1\n"
              "bne       2b                    @bne s3_max_loop_mid_1\n"
              "4:                              @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];  // std::numeric_limits<float>::min();
          for (int i = wstart; i < wend; i++) {
            tmp = std::max(tmp, std::max(r0[i], r1[i]));
            tmp = std::max(tmp, r2[i]);
          }
          dout_ch[w_even >> 1] = tmp;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        int hstart = (h >> 1) * stride - padding;
        int hend = std::min(std::min(hstart + kernel, hin + padding), hin);
        if (hstart == hend - 1) {
          // only one line
          dout_ch[0] = std::max(r0[0], r0[1]);
#if __aarch64__
          w = 1;
          cnt = 1;
          for (; w < w_unroll_size; w += 8) {
            float32x4_t vmax_1234 = vld1q_f32(&r0[w]);
            float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]);
            float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]);
            float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
            float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
            float32x2_t vmax_12_34 =
                vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
            float32x2_t vmax_23_45 =
                vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
            float32x2_t vmax_56_78 =
                vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
            float32x2_t vmax_67_89 =
                vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
            float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
            float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
            vst1_f32(&dout_ch[cnt], vmax_123_345);
            vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
            cnt += 4;
          }
          for (; w < w_even - 1; w += 2) {
            float32x4_t vr0 = vld1q_f32(&r0[w]);
            vr0 = vsetq_lane_f32(minval, vr0, 3);
            float32x2_t vmax =
                vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0));
            vmax = vpmax_f32(vmax, vmax);
            dout_ch[cnt] = vget_lane_f32(vmax, 0);
            cnt++;
          }
#else
          dr_out = dout_ch + 1;
          dr0 = (r0 + 1);
          cnt_num = w_unroll_size >> 3;
          cnt_num_remain = w_unroll_remain >> 1;
          // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
          // cnt_num_remain;
          if (cnt_num > 0 || cnt_num_remain > 0) {
            asm volatile(
                "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
                "ble       3f                    @ble exit\n"
                "1:                              @main loop\n"
                "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d3,dr0\n"
                "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d3,dr0\n"
                "vext.f32  q4, q0, q1, #1        @vmax_2345\n"
                "vext.f32  q5, q1, q2, #1        @vmax_6789\n"
                "vpmax.f32 d12, d0, d1           @vmax_12_34\n"
                "vpmax.f32 d14, d2, d3           @vmax_56_78\n"
                "vpmax.f32 d13, d8, d9           @vmax_23_45\n"
                "vpmax.f32 d15, d10, d11         @vmax_67_89\n"
                "vmax.f32  d0, d12, d13          @12_34,23_45\n"
                "vmax.f32  d1, d14, d15          @56_78,67_89\n"
                "sub       %[dr0], #16           @add w,6\n"
                "vst1.f32  d0, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "vst1.f32  d1, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
                "bne       1b                    @bne s3_max_loop_bot\n"
                "3:                              @loop\n"
                "cmp       %[cnt_num_remain], #0 @cmp cnt_num,0\n"
                "ble       4f                    @ble exit\n"
                "2:                              @bot loop\n"
                "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
                "vmov.f32  s3,s2                 @movs3, s2\n"
                "vpmax.f32 d0, d0, d1            @pmax d0,d0,d1\n"
                "vpmax.f32 d0, d0, d0            @pmax d0,d0,d0\n"
                "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
                "sub       %[dr0], #8            @add w,2\n"
                "subs      %[cnt_num_remain], #1 @subs cnt_num,#1\n"
                "bne       2b                    @bne s3_max_loop_bot_1\n"
                "4:                              @exit\n"
                : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                  [cnt_num] "+r"(cnt_num),
                  [cnt_num_remain] "+r"(cnt_num_remain)
                : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                  "r"(cnt_num_remain)
                : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                  "q7", "q8");
          }
#endif
          if (w_remain > 0) {
            // deal with right pad
            int wstart = (w_even >> 1) * stride - padding;
            int wend = std::min(std::min(wstart + kernel, win + padding), win);
            float tmp = r0[wstart];  // std::numeric_limits<float>::min();
            for (int i = wstart; i < wend; i++) {
              tmp = std::max(tmp, r0[i]);
            }
            dout_ch[w_even >> 1] = tmp;
          }
        } else {
          // two lines
          dout_ch[0] =
              std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
#ifdef __aarch64__
          w = 1;
          cnt = 1;
          for (; w < w_unroll_size; w += 8) {
            float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
            float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
            float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
            float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
            float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
            float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
            float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
            float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
            float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
            float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
            float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
            float32x2_t vmax_12_34 =
                vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
            float32x2_t vmax_23_45 =
                vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
            float32x2_t vmax_56_78 =
                vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
            float32x2_t vmax_67_89 =
                vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
            float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
            float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
            vst1_f32(&dout_ch[cnt], vmax_123_345);
            vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
            cnt += 4;
          }
          for (; w < w_even - 1; w += 2) {
            float32x4_t vr0 = vld1q_f32(&r0[w]);
            float32x4_t vr1 = vld1q_f32(&r1[w]);
            vr0 = vsetq_lane_f32(minval, vr0, 3);
            vr1 = vsetq_lane_f32(minval, vr1, 3);
            float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
            float32x2_t vmax2 =
                vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
            vmax2 = vpmax_f32(vmax2, vmax2);
            dout_ch[cnt] = vget_lane_f32(vmax2, 0);
            cnt++;
          }
#else
          dr_out = dout_ch + 1;
          dr0 = (r0 + 1);
          dr1 = (r1 + 1);
          cnt_num = w_unroll_size >> 3;
          cnt_num_remain = w_unroll_remain >> 1;
          // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
          // cnt_num_remain;
          if (cnt_num > 0 || cnt_num_remain > 0) {
            asm volatile(
                "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
                "ble       3f                    @ble exit\n"
                "1:                              @main loop\n"
                "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d5,dr0\n"
                "vld1.f32  {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
                "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d3,dr0\n"
                "vld1.f32  {d10-d11}, [%[dr1]]!  @load d4-d7,dr1\n"
                "vmax.f32  q6, q0, q3            @max q0,q0,q2 1234\n"
                "vmax.f32  q7, q1, q4            @max q1,q1,q3 5678\n"
                "vmax.f32  q8, q2, q5            @max q1,q1,q3 9101112\n"
                //"vmov.f32 s7,s6                @mov s7, s6\n"
                "vext.f32  q0, q6, q7, #1        @vext q0,2345\n"
                "vext.f32  q1, q7, q8, #1        @vext q1,6789\n"
                "vpmax.f32 d4, d12, d13          @pmax d4,vmax_1234,vmax_1234\n"
                "vpmax.f32 d6, d14, d15          @pmax d6,vmax_5678,vmax_5678\n"
                "vpmax.f32 d5, d0, d1            @pmax d5,vmax_2345,vmax_2345\n"
                "vpmax.f32 d7, d2, d3            @pmax d7,vmax_6789,vmax_6789\n"
                "vmax.f32  d8, d4, d5            @max d2,vmax_12_34,vmax_23_45\n"
                "vmax.f32  d9, d6, d7            @max d2,vmax_56_78,vmax_67_89\n"
                "sub       %[dr0], #16           @add w,8\n"
                "sub       %[dr1], #16           @add w,8\n"
                "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
                "bne       1b                    @bne s3_max_loop_bot\n"
                "3:                              @loop\n"
                "cmp       %[cnt_num_remain], #0 @cmp cnt_num,0\n"
                "ble       4f                    @ble exit\n"
                "2:                              @bot loop\n"
                "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
                "vld1.f32  {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
                "vmov.f32  s3,s2                 @movs3, s2\n"
                "vmov.f32  s7,s6                 @movs7, s6\n"
                "vmax.f32  q0, q0, q1            @max q0,q0,q1\n"
                "vpmax.f32 d0, d0, d1            @pmax d0,d0,d1\n"
                "vpmax.f32 d0, d0, d0            @pmax d0,d0,d0\n"
                "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
                "sub       %[dr0], #8            @add w,6\n"
                "sub       %[dr1], #8            @add w,6\n"
                "subs      %[cnt_num_remain], #1 @subs cnt_num,#1\n"
                "bne       2b                    @bne s3_max_loop_bot_1\n"
                "4:                              @exit\n"
                : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                  [cnt_num] "+r"(cnt_num),
                  [cnt_num_remain] "+r"(cnt_num_remain)
                : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                  "r"(cnt_num_remain)
                : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                  "q7", "q8", "q9");
          }
#endif
          if (w_remain > 0) {
            // deal with right pad
            int wstart = (w_even >> 1) * stride - padding;
            int wend = std::min(std::min(wstart + kernel, win + padding), win);
            float tmp = r0[wstart];  // std::numeric_limits<float>::min();
            for (int i = wstart; i < wend; i++) {  // only run 1 or 2 times
              tmp = std::max(tmp, std::max(r0[i], r1[i]));
            }
            dout_ch[w_even >> 1] = tmp;
          }
        }
      }
    }
  }
}
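
// pooling3x3s2p1_avg: stride-2 counterpart of the average kernel. Window sums
// are built with vadd/vext exactly as in the max version above, and the
// exclusive divisors cover the clipped-window sizes that occur here (1, 2, 3,
// 4 or 6 taps via coef_1 .. coef_6); partially clipped right-pad windows
// derive their divisor from the actual window width (wend - wstart).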
void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive) {
  int kernel = 3;
  int stride = 2;
  int padding = 1;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = (w_limit >> 1) << 1;
  int h_even = (h_limit >> 1) << 1;
  int w_unroll_size = ((w_even - 1) >> 3) << 3;
  int w_unroll_remain = w_even - 1 - w_unroll_size;
  int w_remain = w_needed - w_limit - padding;
  int h_remain = h_needed - h_limit - padding;
  int w_in_2 = win << 1;
  const float coef = 1.f / 9.f;
  const float coef_1 = exclusive ? 1.f : coef;
  const float coef_2 = exclusive ? 1.f / 2.f : coef;
  const float coef_3 = exclusive ? 1.f / 3.f : coef;
  const float coef_4 = exclusive ? 1.f / 4.f : coef;
  const float coef_6 = exclusive ? 1.f / 6.f : coef;
  float32x4_t vcoef = vdupq_n_f32(coef);
  float32x4_t vcoef_1 = vdupq_n_f32(coef_1);
  float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
  float32x4_t vcoef_3 = vdupq_n_f32(coef_3);
  float32x4_t vcoef_4 = vdupq_n_f32(coef_4);
  float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      int cnt_num = w_unroll_size >> 3;
      int cnt_num_remain = w_unroll_remain >> 1;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 1;
      int cnt = 1;
      float32x4_t vzero = vdupq_n_f32(0.f);
      dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
// first row with zero pad
#ifdef __aarch64__
      for (; w < w_unroll_size; w += 8) {
        float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
        float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
        float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
        float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
        float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
        float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
        float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
        float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
        float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
        float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
        float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
        float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
        float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
        float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
        vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
        float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
        vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
        vsum_123_345 =
            vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
        vsum_123_345 =
            vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
        vsum_123_345 =
            vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
        float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
        vst1q_f32(&dout_ch[cnt], vrst);
        cnt += 4;
      }
      for (; w < w_even - 1; w += 2) {
        float32x4_t vr0 = vld1q_f32(&r0[w]);
        float32x4_t vr1 = vld1q_f32(&r1[w]);
        vr0 = vsetq_lane_f32(0.f, vr0, 3);
        vr1 = vsetq_lane_f32(0.f, vr1, 3);
        float32x4_t vsum1 = vaddq_f32(vr0, vr1);
        float32x2_t vsum2 =
            vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
        vsum2 = vpadd_f32(vsum2, vsum2);
        float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
        dout_ch[cnt] = vget_lane_f32(vrst, 0);
        cnt++;
      }
#else
      dr0 = dr0 + 1;
      dr1 = dr1 + 1;
      dr_out = dr_out + 1;
      // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
      // cnt_num_remain;
      if (cnt_num > 0 || cnt_num_remain > 0) {
        asm volatile(
            "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
            "ble       3f                    @ble exit\n"
            "1:                              @main loop\n"
            "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
            "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d5,dr0\n"
            "vld1.f32  {d10-d11}, [%[dr1]]!  @load d4-d7,dr1\n"
            "vadd.f32  q6, q0, q3            @max r0_1234,r1_1234\n"
            "vadd.f32  q7, q1, q4            @max r0_5678,r1_5678\n"
            "vadd.f32  q8, q2, q5            @max r0_9101112,r1_9101112\n"
            //"vmov.f32 s7,s6                @mov s7, s6\n"
            "vext.f32  q0, q6, q7, #1        @vext max_2345\n"
            "vext.f32  q1, q6, q7, #3        @vext max_4567\n"
            "vext.f32  q2, q6, q7, #2        @vext max_3456\n"
            "vext.f32  q3, q7, q8, #1        @vext max_6789\n"
            "vadd.f32  q4, q6, q0            @add 1234, 2345\n"
            "vadd.f32  q5, q7, q1            @add 5678, 4567\n"
            "vadd.f32  q4, q4, q2            @add 3456, sum1\n"
            "vadd.f32  q5, q5, q3            @add 6789, sum2\n"
            "vmov.f32  s17, s18              @mov\n"
            "vmov.f32  s18, s21              @mov\n"
            "vmov.f32  s19, s23              @mov\n"
            "vmul.f32  q4, q4, %q[vcoef_6]   @mul\n"
            "sub       %[dr0], #16           @add w,8\n"
            "sub       %[dr1], #16           @add w,8\n"
            "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
            "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
            "bne       1b                    @bne s3_max_loop\n"
            "3:                              @loop\n"
            "cmp       %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
            "ble       4f                    @ble exit\n"
            "2:                              @main loop\n"
            "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
            "vld1.f32  {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
            "vext.f32  q0, %q[vzero], q0, #3 @ext v0_0123\n"
            "vext.f32  q1, %q[vzero], q1, #3 @ext v1_0123\n"
            "vadd.f32  q0, q0, q1            @add q0,q0,q1\n"
            "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
            "vpadd.f32 d0, d0, d0            @padd d0, d0,d0\n"
            "vmul.f32  d0, d0, %e[vcoef_6]   @mul\n"
            "sub       %[dr0], #8            @add w,6\n"
            "sub       %[dr1], #8            @add w,6\n"
            "subs      %[cnt_num_remain], #1 @subs cnt_num,#1\n"
            "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
            "bne       2b                    @bne s3_max_loop_1\n"
            "4:                              @exit\n"
            : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
              [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
              [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
            : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
              "r"(cnt_num_remain)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9");
      }
#endif
      // int w = w_even - 1;
      if (w_remain > 0) {
        // deal with right pad
        int wstart = (w_even >> 1) * stride - padding;
        int wend = std::min(std::min(wstart + kernel, win + padding), win);
        float tmp1 = 0.f;  // std::numeric_limits<float>::min();
        float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
        for (int i = wstart; i < wend; i++) {  // only run 1 or 2 times
          tmp1 += (r0[i] + r1[i]);
        }
        dout_ch[w_even >> 1] = tmp1 * tmp2;
        // cnt ++;
      }
      r0 = r1;
      r1 = r0 + win;
      r2 = r1 + win;
      dout_ch += wout;
      int h = 2;
      for (; h < h_even; h += 2) {
        // deal with left pad
        float sum0 = r0[0] + r0[1];
        float sum1 = r1[0] + r1[1];
        float sum2 = r2[0] + r2[1];
        dout_ch[0] = (sum0 + sum1 + sum2) * coef_6;
#ifdef __aarch64__
        w = 1;
        cnt = 1;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
          vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
          vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
          vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even - 1; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          vr1 = vsetq_lane_f32(0.f, vr1, 3);
          vr2 = vsetq_lane_f32(0.f, vr2, 3);
          float32x4_t vsum1 = vaddq_f32(vr0, vr1);
          vsum1 = vaddq_f32(vsum1, vr2);
          float32x2_t vsum2 =
              vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
          float32x2_t vsum = vpadd_f32(vsum2, vsum2);
          dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef;
          cnt++;
        }
#else
        dr_out = dout_ch + 1;
        dr0 = (r0 + 1);
        dr1 = (r1 + 1);
        dr2 = (r2 + 1);
        cnt_num = w_unroll_size >> 3;
        cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
              "ble       3f                    @ble exit\n"
              "1:                              @main loop\n"
              "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d5, dr0\n"
              "vld1.f32  {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
              "vld1.f32  {d12-d15}, [%[dr2]]!  @load d4-d7,dr1\n"
              "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d5,dr0\n"
              "vld1.f32  {d10-d11}, [%[dr1]]!  @load d4-d7,dr1\n"
              "vld1.f32  {d16-d17}, [%[dr2]]!  @load d4-d7,dr1\n"
              "vadd.f32  q9, q0, q3            @max q0,q0,q2\n"
              "vadd.f32  q10, q1, q4           @max q1,q1,q3\n"
              "vadd.f32  q11, q2, q5           @max q1,q1,q3\n"
              "vadd.f32  q6, q9, q6            @max q0,q0,q2 1234\n"
              "vadd.f32  q7, q10, q7           @max q1,q1,q3 5678\n"
              "vadd.f32  q8, q11, q8           @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6                @mov s7, s6\n"
              "vext.f32  q0, q6, q7, #1        @vext max_2345\n"
              "vext.f32  q1, q6, q7, #3        @vext max_4567\n"
              "vext.f32  q2, q6, q7, #2        @vext max_3456\n"
              "vext.f32  q3, q7, q8, #1        @vext max_6789\n"
              "vadd.f32  q4, q6, q0            @add 1234,2345\n"
              "vadd.f32  q5, q7, q1            @add 5678,4567\n"
              "vadd.f32  q4, q4, q2            @add 3456,sum1\n"
              "vadd.f32  q5, q5, q3            @add 6789,sum2\n"
              "vmov.f32  s17, s18              @mov\n"
              "vmov.f32  s18, s21              @mov\n"
              "vmov.f32  s19, s23              @mov\n"
              "vmul.f32  q4, q4, %q[vcoef]     @mul\n"
              "sub       %[dr0], #16           @add w,8\n"
              "sub       %[dr1], #16           @add w,8\n"
              "sub       %[dr2], #16           @add w, 8\n"
              "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
              "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "bne       1b                    @bne s3_max_loop_mid\n"
              "3:                              @loop\n"
              "cmp       %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
              "ble       4f                    @ble exit1\n"
              "2:                              @mid loop\n"
              "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
              "vld1.f32  {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
              "vld1.f32  {d4-d5}, [%[dr2]]!    @load d2-d3,dr1\n"
              "vext.f32  q0, %q[vzero], q0, #3 @ext v0_0123\n"
              "vext.f32  q1, %q[vzero], q1, #3 @ext v1_0123\n"
              "vext.f32  q2, %q[vzero], q2, #3 @ext v1_0123\n"
              "vadd.f32  q0, q0, q1            @add q0,q0,q1\n"
              "vadd.f32  q0, q0, q2            @add q0,q0,q1\n"
              "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0            @padd d0,d0,d0\n"
              "vmul.f32  d0, d0, %e[vcoef]     @mul\n"
              "sub       %[dr0], #8            @add w,6\n"
              "sub       %[dr1], #8            @add w,6\n"
              "sub       %[dr2], #8            @add w,6\n"
              "subs      %[cnt_num_remain], #1 @cnt_num_remain--\n"
              "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
              "bne       2b                    @bne s3_max_loop_mid_1\n"
              "4:                              @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef),
                [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            tmp1 += (r0[i] + r1[i] + r2[i]);
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        int hstart = (h >> 1) * stride - padding;
        int hend = std::min(std::min(hstart + kernel, hin + padding), hin);
        if (hstart == hend - 1) {
          // only one line
          dout_ch[0] = (r0[0] + r0[1]) * coef_2;
#ifdef __aarch64__
          w = 1;
          cnt = 1;
          for (; w < w_unroll_size; w += 8) {
            float32x4_t vsum_1234 = vld1q_f32(&r0[w]);
            float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]);
            float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]);
            float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
            float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
            float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
            float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
            float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
            vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
            float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
            vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
            vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2),
                                          vsum_123_345, 1);
            vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1),
                                          vsum_123_345, 2);
            vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3),
                                          vsum_123_345, 3);
            float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_3);
            vst1q_f32(&dout_ch[cnt], vrst);
            cnt += 4;
          }
          for (; w < w_even - 1; w += 2) {
            float32x4_t vr0 = vld1q_f32(&r0[w]);
            vr0 = vsetq_lane_f32(0.f, vr0, 3);
            float32x2_t vsum =
                vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0));
            vsum = vpadd_f32(vsum, vsum);
            dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef_3;
            cnt++;
          }
#else
          dr_out = dout_ch + 1;
          dr0 = (r0 + 1);
          cnt_num = w_unroll_size >> 3;
          cnt_num_remain = w_unroll_remain >> 1;
          if (cnt_num > 0 || cnt_num_remain > 0) {
            asm volatile(
                "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
                "ble       3f                    @ble exit\n"
                "1:                              @main loop\n"
                "vld1.f32  {d12-d15}, [%[dr0]]!  @load d0-d3,dr0\n"
                "vld1.f32  {d16-d17}, [%[dr0]]!  @load d0-d3,dr0\n"
                "vext.f32  q0, q6, q7, #1        @vext max_2345\n"
                "vext.f32  q1, q6, q7, #3        @vext max_4567\n"
                "vext.f32  q2, q6, q7, #2        @vext max_3456\n"
                "vext.f32  q3, q7, q8, #1        @vext max_6789\n"
                "vadd.f32  q4, q6, q0            @add 1234,2345\n"
                "vadd.f32  q5, q7, q1            @add 5678,4567\n"
                "vadd.f32  q4, q4, q2            @add 3456,sum1\n"
                "vadd.f32  q5, q5, q3            @add 6789,sum2\n"
                "vmov.f32  s17, s18              @mov\n"
                "vmov.f32  s18, s21              @mov\n"
                "vmov.f32  s19, s23              @mov\n"
                "vmul.f32  q4, q4, %q[vcoef_3]   @mul\n"
                "sub       %[dr0], #16           @add w,6\n"
                "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
                "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "vst1.f32  d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "bne       1b                    @bne s3_max_loop_bot\n"
                "3:                              @loop\n"
                "cmp       %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
                "ble       4f                    @ble exit\n"
                "2:                              @bot loop\n"
                "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
                "vext.f32  q0, %q[vzero], q0, #3 @ext v0_0123\n"
                "vpadd.f32 d0, d0, d1            @padd d0,d0,d1\n"
                "vpadd.f32 d0, d0, d0            @padd d0,d0,d0\n"
                "vmul.f32  d0, d0, %e[vcoef_3]   @mul\n"
                "sub       %[dr0], #8            @add w,2\n"
                "subs      %[cnt_num_remain], #1 @cnt_num_remain--\n"
                "vst1.f32  d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
                "bne       2b                    @bne s3_max_loop_bot_1\n"
                "4:                              @exit\n"
                : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                  [cnt_num] "+r"(cnt_num),
                  [cnt_num_remain] "+r"(cnt_num_remain),
                  [vcoef_3] "+w"(vcoef_3), [vzero] "+w"(vzero)
                : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                  "r"(cnt_num_remain)
                : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                  "q7", "q8");
          }
#endif
          if (w_remain > 0) {
            // deal with right pad
            int wstart = (w_even >> 1) * stride - padding;
            int wend = std::min(std::min(wstart + kernel, win + padding), win);
            float tmp1 = 0.f;
            float tmp2 = exclusive ? 1.0f / (1.f * (wend - wstart)) : coef;
            for (int i = wstart; i < wend; i++) {
              tmp1 += r0[i];
            }
            dout_ch[w_even >> 1] = tmp1 * tmp2;
          }
        } else {
          // two lines
          dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
#ifdef __aarch64__
          w = 1;
          cnt = 1;
          for (; w < w_unroll_size; w += 8) {
            float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
            float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
            float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
            float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
            float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
            float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
            float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
            float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
            float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
            float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
            float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
            float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
            float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
            float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
            vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
            float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
            vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
            vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2),
                                          vsum_123_345, 1);
            vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1),
                                          vsum_123_345, 2);
            vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3),
                                          vsum_123_345, 3);
            float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
            vst1q_f32(&dout_ch[cnt], vrst);
            cnt += 4;
          }
          for (; w < w_even - 1; w += 2) {
            float32x4_t vr0 = vld1q_f32(&r0[w]);
            float32x4_t vr1 = vld1q_f32(&r1[w]);
            vr0 = vsetq_lane_f32(0.f, vr0, 3);
            vr1 = vsetq_lane_f32(0.f, vr1, 3);
            float32x4_t vsum1 = vaddq_f32(vr0, vr1);
            float32x2_t vsum2 =
                vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
            vsum2 = vpadd_f32(vsum2, vsum2);
            float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
            dout_ch[cnt] = vget_lane_f32(vrst, 0);
            cnt++;
          }
#else
          dr_out = dout_ch + 1;
          dr0 = (r0 + 1);
          dr1 = (r1 + 1);
          cnt_num = w_unroll_size >> 3;
          cnt_num_remain = w_unroll_remain >> 1;
          if (cnt_num > 0 || cnt_num_remain > 0) {
            asm volatile(
                "cmp       %[cnt_num], #0        @cmp cnt_num,0\n"
                "ble       3f                    @ble exit\n"
                "1:                              @main loop\n"
                "vld1.f32  {d0-d3}, [%[dr0]]!    @load d0-d5,dr0\n"
                "vld1.f32  {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
                "vld1.f32  {d4-d5}, [%[dr0]]!    @load d0-d3,dr0\n"
                "vld1.f32  {d10-d11}, [%[dr1]]!  @load d4-d7,dr1\n"
                "vadd.f32  q6, q0, q3            @add q0,q0,q2 1234\n"
                "vadd.f32  q7, q1, q4            @add q1,q1,q3 5678\n"
                "vadd.f32  q8, q2, q5            @add q1,q1,q3 9101112\n"
                //"vmov.f32 s7,s6                @mov s7,s6\n"
                "vext.f32  q0, q6, q7, #1        @vext max_2345\n"
                "vext.f32  q1, q6, q7, #3        @vext max_4567\n"
                "vext.f32  q2, q6, q7, #2        @vext max_3456\n"
                "vext.f32  q3, q7, q8, #1        @vext max_6789\n"
                "vadd.f32  q4, q6, q0            @add 1234,2345\n"
                "vadd.f32  q5, q7, q1            @add 5678,4567\n"
                "vadd.f32  q4, q4, q2            @add 3456,sum1\n"
                "vadd.f32  q5, q5, q3            @add 6789,sum2\n"
                "vmov.f32  s17, s18              @mov\n"
                "vmov.f32  s18, s21              @mov\n"
                "vmov.f32  s19, s23              @mov\n"
                "vmul.f32  q4, q4, %q[vcoef_6]   @mul\n"
                "sub       %[dr0], #16           @add w,8\n"
                "sub       %[dr1], #16           @add w,8\n"
                "subs      %[cnt_num], #1        @subs cnt_num,#1\n"
                "vst1.f32  d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
                "vst1.f32  d9, [%[dr_out]]!      @vst1 d0, dr_out\n"
                "bne       1b                    @bne s3_max_loop_bot\n"
                "3:                              @loop\n"
                "cmp       %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
                "ble       4f                    @ble exit\n"
                "2:                              @bot loop\n"
                "vld1.f32  {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
                "vld1.f32  {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
                "vext.f32  q0, %q[vzero], q0, #3 @ext v0_0123\n"
                "vext.f32  q1, %q[vzero], q1, #3 @ext v1_0123\n"
\n
"
"vadd.f32 q0, q0, q1 @add q0,q0,q1
\n
"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1
\n
"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0
\n
"
"vmul.f32 d0, d0, %e[vcoef_6] @mul
\n
"
"sub %[dr0], #8 @add w,6
\n
"
"sub %[dr1], #8 @add w,6
\n
"
"subs %[cnt_num_remain], #1 @cnt_num_remain--
\n
"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out
\n
"
"bne 2b @bne s3_max_loop_bot_1
\n
"
"4: @exit
\n
"
:
[
dr0
]
"+r"
(
dr0
),
[
dr1
]
"+r"
(
dr1
),
[
dr_out
]
"+r"
(
dr_out
),
[
cnt_num
]
"+r"
(
cnt_num
),
[
cnt_num_remain
]
"+r"
(
cnt_num_remain
),
[
vcoef_6
]
"+w"
(
vcoef_6
),
[
vzero
]
"+w"
(
vzero
)
:
"r"
(
dr0
),
"r"
(
dr1
),
"r"
(
dr_out
),
"r"
(
cnt_num
),
"r"
(
cnt_num_remain
)
:
"cc"
,
"memory"
,
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
);
}
#endif
if
(
w_remain
>
0
)
{
// deal with right pad
int
wstart
=
(
w_even
>>
1
)
*
stride
-
padding
;
int
wend
=
std
::
min
(
std
::
min
(
wstart
+
kernel
,
win
+
padding
),
win
);
float
tmp1
=
0.
f
;
float
tmp2
=
exclusive
?
1.0
f
/
(
2.
f
*
(
wend
-
wstart
))
:
coef
;
for
(
int
i
=
wstart
;
i
<
wend
;
i
++
)
{
// only run 1 or 2 times
tmp1
+=
(
r0
[
i
]
+
r1
[
i
]);
}
dout_ch
[
w_even
>>
1
]
=
tmp1
*
tmp2
;
}
}
}
}
}
}
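
Reviewer note on the bottom-pad branches above: with one valid input row the window average uses coef_3 and with two rows coef_6, and in exclusive mode the right-pad divisor shrinks to the count of valid columns instead of the full kernel size. The scalar sketch below restates that rule; it is illustrative only, not part of the deleted file, and the name avg_window_2rows and the parameter full_coef are hypothetical.

// Illustrative sketch, not from the original source.
float avg_window_2rows(const float* r0, const float* r1, int wstart, int wend,
                       bool exclusive, float full_coef) {
  float sum = 0.f;
  for (int i = wstart; i < wend; ++i) {
    sum += r0[i] + r1[i];  // two valid input rows contribute
  }
  // exclusive: divide by the number of valid elements (2 rows * columns);
  // inclusive: keep the full-kernel coefficient (e.g. 1/9 for 3x3).
  float coef = exclusive ? 1.0f / (2.f * (wend - wstart)) : full_coef;
  return sum * coef;
}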
void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win) {
  int kernel = 3;
  int stride = 2;
  int padding = 0;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = ((w_limit - 1) >> 1) << 1;
  int h_even = ((h_limit - 1) >> 1) << 1;
  int w_unroll_size = (w_even >> 3) << 3;
  int w_unroll_remain = w_even - w_unroll_size;
  int w_remain = w_needed - w_limit;
  int h_remain = h_needed - h_limit;
  int w_in_2 = win << 1;
  float minval = std::numeric_limits<float>::lowest();
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      // w = w_in - 8;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      int w = 0;
      int cnt = 0;
      // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0],
      // r1[1]));
      // first row with zero pad
      // r0 = r1;
      // r1 = r0 + w_in;
      // r2 = r1 + w_in;
      // dout_channel += w_out;
      int h = 0;
      for (; h < h_even; h += 2) {
        // deal with left pad
        float maxr0 = std::max(r0[0], r0[1]);
        float maxr1 = std::max(r1[0], r1[1]);
        float maxr2 = std::max(r2[0], r2[1]);
        // dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
        w = 0;
        cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
          float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
          vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_56_78 =
              vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
          float32x2_t vmax_67_89 =
              vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
          vst1_f32(&dout_ch[cnt], vmax_123_345);
          vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(minval, vr0, 3);
          vr1 = vsetq_lane_f32(minval, vr1, 3);
          vr2 = vsetq_lane_f32(minval, vr2, 3);
          float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
          vmax1 = vmaxq_f32(vmax1, vr2);
          float32x2_t vmax2 =
              vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
          float32x2_t vmax = vpmax_f32(vmax2, vmax2);
          dout_ch[cnt] = vget_lane_f32(vmax, 0);
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        dr2 = r2;          // (r2 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0            @cmp cnt_num,0\n"
              "ble 3f                        @ble exit\n"
              "1:                            @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!   @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!   @load d4-d7,dr1\n"
              "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n"
              "vld1.f32 {d4}, [%[dr0]]!      @load d0-d5,dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!     @load d4-d7,dr1\n"
              "vld1.f32 {d16}, [%[dr2]]!     @load d4-d7,dr1\n"
              "vmax.f32 q9, q0, q3           @max q0,q0,q2\n"
              "vmax.f32 q10, q1, q4          @max q1,q1,q3\n"
              "vmax.f32 d22, d4, d10         @max q1,q1,q3\n"
              "vmax.f32 q0, q9, q6           @max q0,q0,q2 1234\n"
              "vmax.f32 q3, q10, q7          @max q1,q1,q3 5678\n"
              "vmax.f32 d2, d22, d16         @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6              @mov s7, s6\n"
              "vext.f32 q4, q0, q3, #1       @vext 2345\n"
              "vext.f32 q2, q3, q1, #1       @vext 6789\n"
              "vpmax.f32 d10, d0, d1         @pmax d10,vmax_1234,vmax_1234\n"
              "vpmax.f32 d12, d6, d7         @pmax d12,vmax_5678,vmax_5678\n"
              "vpmax.f32 d11, d8, d9         @pmax d11,vmax_2345,vmax_2345\n"
              "vpmax.f32 d13, d4, d5         @pmax d13,vmax_6789,vmax_6789\n"
              "vmax.f32 d0, d10, d11         @pmax d0,vmax_12_34,vmax_23_45\n"
              "vmax.f32 d1, d12, d13         @pmax d1,vmax_56_78,vmax_67_89\n"
              "sub %[dr0], #8                @add w,8\n"
              "sub %[dr1], #8                @add w,8\n"
              "sub %[dr2], #8                @add w,8\n"
              "vst1.f32 d0, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "vst1.f32 d1, [%[dr_out]]!     @vst1 d0,dr_out\n"
              "subs %[cnt_num], #1           @cnt_num--\n"
              "bne 1b                        @bne s3_max_loop_mid\n"
              "3:                            @loop\n"
              "cmp %[cnt_num_remain], #0     @cmp cnt_num_remain,0\n"
              "ble 4f                        @ble exit1\n"
              "2:                            @mid loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!   @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!   @load d2-d3,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr2]]!   @load d2-d3,dr1\n"
              "vmov.f32 s3,s2                @movs3,s2\n"
              "vmov.f32 s7,s6                @movs7,s6\n"
              "vmov.f32 s11,s10              @movs11,s10\n"
              "vmax.f32 q0, q0, q1           @max q0,q0,q1\n"
              "vmax.f32 q0, q0, q2           @max q0,q0,q2\n"
              "vpmax.f32 d0, d0, d1          @pmax d0,d0,d1\n"
              "vpmax.f32 d0, d0, d0          @pmax d0,d0,d0\n"
              "vst1.f32 d0[0], [%[dr_out]]!  @vst d0[0],dr_out\n"
              "sub %[dr0], #8                @add w,6\n"
              "sub %[dr1], #8                @add w,6\n"
              "sub %[dr2], #8                @add w,6\n"
              "subs %[cnt_num_remain], #1    @cnt_num_remain--\n"
              "bne 2b                        @bne s3_max_loop_mid_1\n"
              "4:                            @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                "q8", "q9", "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];  // std::numeric_limits<float>::min();
          for (int i = wstart; i < wend; i++) {
            tmp = std::max(tmp, std::max(r0[i], r1[i]));
            tmp = std::max(tmp, r2[i]);
          }
          dout_ch[w_even >> 1] = tmp;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        // int hstart = (h >> 1) * stride_h - pad_h;
        // int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin);
        // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0],
        // r1[1]));
#ifdef __aarch64__
        w = 0;
        cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
          float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
          float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
          float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
          float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
          float32x2_t vmax_12_34 =
              vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
          float32x2_t vmax_23_45 =
              vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
          float32x2_t vmax_56_78 =
              vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
          float32x2_t vmax_67_89 =
              vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
          float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
          float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
          vst1_f32(&dout_ch[cnt], vmax_123_345);
          vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          vr0 = vsetq_lane_f32(minval, vr0, 3);
          vr1 = vsetq_lane_f32(minval, vr1, 3);
          float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
          float32x2_t vmax2 =
              vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
          vmax2 = vpmax_f32(vmax2, vmax2);
          dout_ch[cnt] = vget_lane_f32(vmax2, 0);
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0           @cmp cnt_num,0\n"
              "ble 3f                       @ble exit\n"
              "1:                           @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!  @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!  @load d4-d7,dr1\n"
              "vld1.f32 {d4}, [%[dr0]]!     @load d0-d3,dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!    @load d4-d7,dr1\n"
              "vmax.f32 q6, q0, q3          @max q0,q0,q2 1234\n"
              "vmax.f32 q7, q1, q4          @max q1,q1,q3 5678\n"
              "vmax.f32 d16, d4, d10        @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6             @mov s7,s6\n"
              "vext.f32 q0, q6, q7, #1      @vext q0,2345\n"
              "vext.f32 q1, q7, q8, #1      @vext q1,6789\n"
              "vpmax.f32 d4, d12, d13       @pmax d4,vmax_1234,vmax_1234\n"
              "vpmax.f32 d6, d14, d15       @pmax d6,vmax_5678,vmax_5678\n"
              "vpmax.f32 d5, d0, d1         @pmax d5,vmax_2345,vmax_2345\n"
              "vpmax.f32 d7, d2, d3         @pmax d7,vmax_6789,vmax_6789\n"
              "vmax.f32 d8, d4, d5          @max d2,vmax_12_34,vmax_23_45\n"
              "vmax.f32 d9, d6, d7          @max d2,vmax_56_78,vmax_67_89\n"
              "sub %[dr0], #8               @add w,8\n"
              "sub %[dr1], #8               @add w,8\n"
              "vst1.f32 d8, [%[dr_out]]!    @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!    @vst1 d0,dr_out\n"
              "subs %[cnt_num], #1          @subs cnt_num,#1\n"
              "bne 1b                       @bne s3_max_loop_bot\n"
              "3:                           @loop\n"
              "cmp %[cnt_num_remain], #0    @cmp cnt_num_remain,0\n"
              "ble 4f                       @ble exit\n"
              "2:                           @bot loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!  @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!  @load d2-d3,dr1\n"
              "vmov.f32 s3,s2               @movs3,s2\n"
              "vmov.f32 s7,s6               @movs7,s6\n"
              "vmax.f32 q0, q0, q1          @max q0,q0,q1\n"
              "vpmax.f32 d0, d0, d1         @pmax d0,d0,d1\n"
              "vpmax.f32 d0, d0, d0         @pmax d0,d0,d0\n"
              "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
              "sub %[dr0], #8               @add w,6\n"
              "sub %[dr1], #8               @add w,6\n"
              "subs %[cnt_num_remain], #1   @cnt_num_remain--\n"
              "bne 2b                       @bne s3_max_loop_bot_1\n"
              "4:                           @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                "q8", "q9");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];  // std::numeric_limits<float>::min();
          for (int i = wstart; i < wend; i++) {
            // only run 1 or 2 times
            tmp = std::max(tmp, std::max(r0[i], r1[i]));
          }
          dout_ch[w_even >> 1] = tmp;
        }
      }
    }
  }
}
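
For context, a hedged sketch of how a caller might size the output for this k=3, s=2, p=0 max kernel, assuming floor output geometry consistent with the kernel's own w_needed = 2 * wout + 1 bookkeeping. This is illustrative only; run_pool3x3s2p0_max and the buffer names are hypothetical.

// Illustrative sketch, not from the original source.
#include <vector>

void run_pool3x3s2p0_max(const std::vector<float>& in, int num, int ch,
                         int hin, int win, std::vector<float>* out) {
  // floor mode for kernel 3, stride 2, padding 0 (assumed)
  int hout = (hin - 3) / 2 + 1;
  int wout = (win - 3) / 2 + 1;
  out->resize(static_cast<size_t>(num) * ch * hout * wout);
  paddle::lite::arm::math::pooling3x3s2p0_max(in.data(), out->data(), num, ch,
                                              hout, wout, ch, hin, win);
}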
void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive) {
  int kernel = 3;
  int stride = 2;
  int padding = 0;
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  int w_needed = (wout << 1) + 1;
  int h_needed = (hout << 1) + 1;
  int w_limit = w_needed > win ? win : w_needed;
  int h_limit = h_needed > hin ? hin : h_needed;
  int w_even = ((w_limit - 1) >> 1) << 1;
  int h_even = ((h_limit - 1) >> 1) << 1;
  int w_unroll_size = (w_even >> 3) << 3;
  int w_unroll_remain = w_even - w_unroll_size;
  int w_remain = w_needed - w_limit;
  int h_remain = h_needed - h_limit;
  int w_in_2 = win << 1;
  const float coef = 1.f / 9.f;
  const float coef_6 = exclusive ? 1.f / 6.f : coef;
  float32x4_t vcoef = vdupq_n_f32(coef);
  float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
  for (int n = 0; n < num; ++n) {
    float* dout_batch = dout + n * chout * size_channel_out;
    const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* dout_ch = dout_batch + c * size_channel_out;
      const float* din_ch = din_batch + c * size_channel_in;
      const float* r0 = din_ch;
      const float* r1 = r0 + win;
      const float* r2 = r1 + win;
      // w = w_in - 8;
      float* dr_out = dout_ch;
      const float* dr0 = r0;
      const float* dr1 = r1;
      const float* dr2 = r2;
      float32x4_t vzero = vdupq_n_f32(0.f);
      int h = 0;
      for (; h < h_even; h += 2) {
        // LOG(INFO) << "h: " << h <<", dr0:" << r0 << ", dr1: " << r1 <<
        // ",dr2: " <<r2; deal with left pad float sum0 = r0[0] + r0[1]; float
        // sum1 = r1[0] + r1[1]; float sum2 = r2[0] + r2[1]; dout_channel[0] =
        // (sum0 + sum1 + sum2) / 9.f;
#ifdef __aarch64__
        int w = 0;
        int cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
          float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
          float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
          vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
          vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
          vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          float32x4_t vr2 = vld1q_f32(&r2[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          vr1 = vsetq_lane_f32(0.f, vr1, 3);
          vr2 = vsetq_lane_f32(0.f, vr2, 3);
          float32x4_t vsum1 = vaddq_f32(vr0, vr1);
          vsum1 = vaddq_f32(vsum1, vr2);
          float32x2_t vsum2 =
              vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
          float32x2_t vsum = vpadd_f32(vsum2, vsum2);
          dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef;
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        dr2 = r2;          // (r2 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
        // cnt_num_remain;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0             @cmp cnt_num, 0\n"
              "ble loop3_ave_p0               @ble exit\n"
              "s3_ave_loop_mid_p0:            @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!    @load d0-d5, dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!    @load d4-d7, dr1\n"
              "vld1.f32 {d12-d15}, [%[dr2]]!  @load d4-d7, dr2\n"
              "vld1.f32 {d4}, [%[dr0]]!       @load d0-d5, dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!      @load d4-d7, dr1\n"
              "vld1.f32 {d16}, [%[dr2]]!      @load d4-d7, dr2\n"
              "vadd.f32 q9, q0, q3            @max q0,q0,q2\n"
              "vadd.f32 q10, q1, q4           @max q1,q1,q3\n"
              "vadd.f32 d22, d4, d10          @max q1,q1,q3\n"
              "vadd.f32 q6, q9, q6            @max q0,q0,q2 1234\n"
              "vadd.f32 q7, q10, q7           @max q1,q1,q3 5678\n"
              "vadd.f32 d16, d22, d16         @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6               @mov s7, s6\n"
              "vext.f32 q0, q6, q7, #1        @vext max_2345\n"
              "vext.f32 q1, q6, q7, #3        @vext max_4567\n"
              "vext.f32 q2, q6, q7, #2        @vext max_3456\n"
              "vext.f32 q3, q7, q8, #1        @vext max_6789\n"
              "vadd.f32 q4, q6, q0            @add 1234, 2345\n"
              "vadd.f32 q5, q7, q1            @add 5678, 4567\n"
              "vadd.f32 q4, q4, q2            @add 3456, sum1\n"
              "vadd.f32 q5, q5, q3            @add 6789, sum2\n"
              "vmov.f32 s17, s18              @mov\n"
              "vmov.f32 s18, s21              @mov\n"
              "vmov.f32 s19, s23              @mov\n"
              "vmul.f32 q4, q4, %q[vcoef]     @mul\n"
              "sub %[dr0], #8                 @add w,8\n"
              "sub %[dr1], #8                 @add w,8\n"
              "sub %[dr2], #8                 @add w,8\n"
              "subs %[cnt_num], #1            @cnt_num--\n"
              "vst1.f32 d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "bne s3_ave_loop_mid_p0         @bne s3_max_loop_mid\n"
              "loop3_ave_p0:                  @loop\n"
              "cmp %[cnt_num_remain], #0      @cmp cnt_num_remain,0\n"
              "ble exit1_ave_p0               @ble exit1\n"
              "s3_ave_loop_mid_1_p0:          @mid loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
              "vld1.f32 {d4-d5}, [%[dr2]]!    @load d2-d3,dr1\n"
              "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
              "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
              "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n"
              "vadd.f32 q0, q0, q1            @add q0,q0,q1\n"
              "vadd.f32 q0, q0, q2            @add q0,q0,q1\n"
              "vpadd.f32 d0, d0, d1           @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0           @padd d0,d0,d0\n"
              "vmul.f32 d0, d0, %e[vcoef]     @mul\n"
              "sub %[dr0], #8                 @add w,6\n"
              "sub %[dr1], #8                 @add w,6\n"
              "sub %[dr2], #8                 @add w,6\n"
              "subs %[cnt_num_remain], #1     @cnt_num_remain--\n"
              "vst1.f32 d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
              "bne s3_ave_loop_mid_1_p0       @bne s3_max_loop_mid_1\n"
              "exit1_ave_p0:                  @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
                [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
                [cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef),
                [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
                "q10", "q11", "q12");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            tmp1 += (r0[i] + r1[i] + r2[i]);
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
          // cnt ++;
        }
        r0 = r2;
        r1 = r0 + win;
        r2 = r1 + win;
        dout_ch += wout;
      }
      if (h_remain > 0) {
        // deal with bottom pad
        // first row with zero pad
        // int hstart = (h >> 1) * stride_h - pad_h;
        // int hend = std::min(std::min(hstart + kernel_h, hin + padding_h),
        // hin); data_out_channel[0] =(r0[0] + r0[1] + r0[2] + r1[0] + r1[1] +
        // r1[2]) / 9.f;
#ifdef __aarch64__
        int w = 0;
        int cnt = 0;
        for (; w < w_unroll_size; w += 8) {
          float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
          float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
          float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
          float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
          float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
          float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
          float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
          float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
          float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
          float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
          float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
          float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
          float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
          float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
          vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
          float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
          vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
          vsum_123_345 =
              vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
          float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
          vst1q_f32(&dout_ch[cnt], vrst);
          cnt += 4;
        }
        for (; w < w_even; w += 2) {
          float32x4_t vr0 = vld1q_f32(&r0[w]);
          float32x4_t vr1 = vld1q_f32(&r1[w]);
          vr0 = vsetq_lane_f32(0.f, vr0, 3);
          vr1 = vsetq_lane_f32(0.f, vr1, 3);
          float32x4_t vsum1 = vaddq_f32(vr0, vr1);
          float32x2_t vsum2 =
              vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
          vsum2 = vpadd_f32(vsum2, vsum2);
          float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
          dout_ch[cnt] = vget_lane_f32(vrst, 0);
          cnt++;
        }
#else
        dr_out = dout_ch;  // + 1;
        dr0 = r0;          // (r0 + 1);
        dr1 = r1;          // (r1 + 1);
        int cnt_num = w_unroll_size >> 3;
        int cnt_num_remain = w_unroll_remain >> 1;
        // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
        // cnt_num_remain;
        if (cnt_num > 0 || cnt_num_remain > 0) {
          asm volatile(
              "cmp %[cnt_num], #0             @cmp cnt_num,0\n"
              "ble 2f                         @ble exit\n"
              "1:                             @main loop\n"
              "vld1.f32 {d0-d3}, [%[dr0]]!    @load d0-d5,dr0\n"
              "vld1.f32 {d6-d9}, [%[dr1]]!    @load d4-d7,dr1\n"
              "vld1.f32 {d4}, [%[dr0]]!       @load d0-d3,dr0\n"
              "vld1.f32 {d10}, [%[dr1]]!      @load d4-d7,dr1\n"
              "vadd.f32 q6, q0, q3            @max q0,q0,q2 1234\n"
              "vadd.f32 q7, q1, q4            @max q1,q1,q3 5678\n"
              "vadd.f32 d16, d4, d10          @max q1,q1,q3 9101112\n"
              //"vmov.f32 s7,s6               @mov s7, s6\n"
              "vext.f32 q0, q6, q7, #1        @vext max_2345\n"
              "vext.f32 q1, q6, q7, #3        @vext max_4567\n"
              "vext.f32 q2, q6, q7, #2        @vext max_3456\n"
              "vext.f32 q3, q7, q8, #1        @vext max_6789\n"
              "vadd.f32 q4, q6, q0            @add 1234,2345\n"
              "vadd.f32 q5, q7, q1            @add 5678,4567\n"
              "vadd.f32 q4, q4, q2            @add 3456,sum1\n"
              "vadd.f32 q5, q5, q3            @add 6789,sum2\n"
              "vmov.f32 s17, s18              @mov\n"
              "vmov.f32 s18, s21              @mov\n"
              "vmov.f32 s19, s23              @mov\n"
              "vmul.f32 q4, q4, %q[vcoef_6]   @mul\n"
              "sub %[dr0], #8                 @add w,8\n"
              "sub %[dr1], #8                 @add w,8\n"
              "subs %[cnt_num], #1            @cnt_num--\n"
              "vst1.f32 d8, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "vst1.f32 d9, [%[dr_out]]!      @vst1 d0,dr_out\n"
              "bne 1b                         @bne s3_max_loop_bot\n"
              "2:                             @loop\n"
              "cmp %[cnt_num_remain], #0      @cmp cnt_num_remain, 0\n"
              "ble 3f                         @ble exit\n"
              "4:                             @bot loop\n"
              "vld1.f32 {d0-d1}, [%[dr0]]!    @load d0-d1,dr0\n"
              "vld1.f32 {d2-d3}, [%[dr1]]!    @load d2-d3,dr1\n"
              "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
              "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
              "vadd.f32 q0, q0, q1            @add q0,q0,q1\n"
              "vpadd.f32 d0, d0, d1           @padd d0,d0,d1\n"
              "vpadd.f32 d0, d0, d0           @padd d0,d0,d0\n"
              "vmul.f32 d0, d0, %e[vcoef_6]   @mul\n"
              "sub %[dr0], #8                 @add w,6\n"
              "sub %[dr1], #8                 @add w,6\n"
              "subs %[cnt_num_remain], #1     @cnt_num_remain--\n"
              "vst1.f32 d0[0], [%[dr_out]]!   @vst d0[0],dr_out\n"
              "bne 4b                         @bne s3_max_loop_bot_1\n"
              "3:                             @exit\n"
              : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
                [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
                [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
                "r"(cnt_num_remain)
              : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9");
        }
#endif
        if (w_remain > 0) {
          // deal with right pad
          int wstart = (w_even >> 1) * stride - padding;
          int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp1 = 0.f;
          float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
          for (int i = wstart; i < wend; i++) {
            // only run 1 or 2 times
            tmp1 += (r0[i] + r1[i]);
          }
          dout_ch[w_even >> 1] = tmp1 * tmp2;
        }
      }
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
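
The sliding-window idiom used throughout this file builds shifted copies of a sum/max vector with vext and folds neighbours with pairwise max (or add), yielding four stride-2 window results per iteration. The standalone sketch below isolates the max variant exactly as the kernels above use it; it is illustrative only and the function name is hypothetical.

// Illustrative sketch, not from the original source.
#include <arm_neon.h>

static inline void max3x1_s2_x4(const float* in, float* out) {
  float32x4_t v1234 = vld1q_f32(in);      // x1..x4
  float32x4_t v5678 = vld1q_f32(in + 4);  // x5..x8
  float32x4_t v9_12 = vld1q_f32(in + 8);  // x9..x12
  float32x4_t v2345 = vextq_f32(v1234, v5678, 1);
  float32x4_t v6789 = vextq_f32(v5678, v9_12, 1);
  // pairwise max folds neighbours: {max(x1,x2), max(x3,x4)}, etc.
  float32x2_t m12_34 = vpmax_f32(vget_low_f32(v1234), vget_high_f32(v1234));
  float32x2_t m23_45 = vpmax_f32(vget_low_f32(v2345), vget_high_f32(v2345));
  float32x2_t m56_78 = vpmax_f32(vget_low_f32(v5678), vget_high_f32(v5678));
  float32x2_t m67_89 = vpmax_f32(vget_low_f32(v6789), vget_high_f32(v6789));
  vst1_f32(out, vmax_f32(m12_34, m23_45));      // max(x1..x3), max(x3..x5)
  vst1_f32(out + 2, vmax_f32(m56_78, m67_89));  // max(x5..x7), max(x7..x9)
}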
paddle/fluid/lite/arm/math/pooling.h
deleted 100644 → 0    view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

// !pooling fp32 Op
void pooling_basic(const float* din, float* dout, int num, int chout, int hout,
                   int wout, int chin, int hin, int win,
                   const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, bool global_pooling,
                   bool exclusive, bool adaptive, bool ceil_mode,
                   bool use_quantizer, const std::string& pooling_type);

void pooling_global_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling_global_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win);

void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win,
                      bool exclusive);

void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive);

void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive);

void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win);

void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
                        int hout, int wout, int chin, int hin, int win,
                        bool exclusive);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
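
A plausible dispatcher over these declarations, falling back to pooling_basic for unspecialized shapes. The dispatcher itself is hypothetical and not part of this header; only the callees and their signatures come from the declarations above.

// Illustrative sketch, not from the original source.
#include <string>
#include <vector>

void pooling_dispatch(const float* din, float* dout, int num, int chout,
                      int hout, int wout, int chin, int hin, int win,
                      const std::vector<int>& ksize,
                      const std::vector<int>& strides,
                      const std::vector<int>& paddings, bool exclusive,
                      const std::string& type) {
  using namespace paddle::lite::arm::math;
  bool k3 = ksize[0] == 3 && ksize[1] == 3;
  if (k3 && strides[0] == 2 && paddings[0] == 0 && type == "max") {
    pooling3x3s2p0_max(din, dout, num, chout, hout, wout, chin, hin, win);
  } else if (k3 && strides[0] == 2 && paddings[0] == 0 && type == "avg") {
    pooling3x3s2p0_avg(din, dout, num, chout, hout, wout, chin, hin, win,
                       exclusive);
  } else {
    // generic path; flags are assumed defaults for this sketch
    pooling_basic(din, dout, num, chout, hout, wout, chin, hin, win, ksize,
                  strides, paddings, false, exclusive, false, false, false,
                  type);
  }
}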
paddle/fluid/lite/arm/math/scale.cc
deleted 100644 → 0    view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/scale.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <>
void scale<float>(const float* din, float* dout, int num, float scale,
                  float bias) {
  int cnt = num >> 4;
  int remain = num % 16;
  float32x4_t vscale = vdupq_n_f32(scale);
  float32x4_t vbias = vdupq_n_f32(bias);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);
    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);
    float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
    float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
    float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
    float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
    vst1q_f32(dout_ptr, vsum1);
    vst1q_f32(dout_ptr + 4, vsum2);
    vst1q_f32(dout_ptr + 8, vsum3);
    vst1q_f32(dout_ptr + 12, vsum4);
  }
  if (remain > 0) {
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr * scale + bias;
      dout_ptr++;
      din_ptr++;
    }
  }
}

template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
                  int inner_dim, const float* scale_data,
                  const float* bias_data) {
  int cnt = inner_dim >> 4;
  int remain = inner_dim % 16;
  int size = inner_dim * scale_dim;
  for (int n = 0; n < outer_dim; n++) {
    const float* din_ptr_n = din + n * size;
    float* dout_ptr_n = dout + n * size;
#pragma omp parallel for
    for (int i = 0; i < scale_dim; i++) {
      const float* din_ptr = din_ptr_n + i * inner_dim;
      float* dout_ptr = dout_ptr_n + i * inner_dim;
      float scale = scale_data[i];
      float32x4_t vscale = vdupq_n_f32(scale);
      float bias = bias_data[i];
      float32x4_t vbias = vdupq_n_f32(bias);
      for (int j = 0; j < cnt; j++) {
        float32x4_t din0 = vld1q_f32(din_ptr);
        float32x4_t din1 = vld1q_f32(din_ptr + 4);
        float32x4_t din2 = vld1q_f32(din_ptr + 8);
        float32x4_t din3 = vld1q_f32(din_ptr + 12);
        float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
        float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
        float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
        float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
        din_ptr += 16;
        vst1q_f32(dout_ptr, vsum1);
        vst1q_f32(dout_ptr + 4, vsum2);
        vst1q_f32(dout_ptr + 8, vsum3);
        vst1q_f32(dout_ptr + 12, vsum4);
        dout_ptr += 16;
      }
      for (int j = 0; j < remain; j++) {
        *dout_ptr = *din_ptr * scale + bias;
        dout_ptr++;
        din_ptr++;
      }
    }
  }
}

template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
                  const float* scale_data, const float* bias_data) {
  int cnt = scale_dim >> 4;
  int remain = scale_dim % 16;
  for (int n = 0; n < outer_dim; n++) {
    const float* din_ptr_n = din + n * scale_dim;
    float* dout_ptr_n = dout + n * scale_dim;
#pragma omp parallel for
    for (int i = 0; i < cnt; i++) {
      int idx = i << 4;
      const float* din_ptr = din_ptr_n + idx;
      const float* scale_ptr = scale_data + idx;
      const float* bias_ptr = bias_data + idx;
      float* dout_ptr = dout_ptr_n + idx;
      float32x4_t din0 = vld1q_f32(din_ptr);
      float32x4_t vscale0 = vld1q_f32(scale_ptr);
      float32x4_t vbias0 = vld1q_f32(bias_ptr);
      float32x4_t din1 = vld1q_f32(din_ptr + 4);
      float32x4_t vscale1 = vld1q_f32(scale_ptr + 4);
      float32x4_t vbias1 = vld1q_f32(bias_ptr + 4);
      float32x4_t din2 = vld1q_f32(din_ptr + 8);
      float32x4_t vscale2 = vld1q_f32(scale_ptr + 8);
      float32x4_t vbias2 = vld1q_f32(bias_ptr + 8);
      float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0);
      float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1);
      float32x4_t din3 = vld1q_f32(din_ptr + 12);
      float32x4_t vscale3 = vld1q_f32(scale_ptr + 12);
      float32x4_t vbias3 = vld1q_f32(bias_ptr + 12);
      vst1q_f32(dout_ptr, vsum1);
      vst1q_f32(dout_ptr + 4, vsum2);
      float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2);
      float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3);
      vst1q_f32(dout_ptr + 8, vsum3);
      vst1q_f32(dout_ptr + 12, vsum4);
    }
    int idx = cnt << 4;
    const float* din_ptr = din_ptr_n + idx;
    float* dout_ptr = dout_ptr_n + idx;
    const float* scale_ptr = scale_data + idx;
    const float* bias_ptr = bias_data + idx;
    for (int j = 0; j < remain; j++) {
      *dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr);
      dout_ptr++;
      din_ptr++;
      scale_ptr++;
      bias_ptr++;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
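
All three specializations above reduce to y = x * scale + bias per element, fused into a single vmlaq_f32 multiply-accumulate on NEON. A hedged usage sketch of the first overload follows; apply_scale is a hypothetical wrapper, not from the original sources.

// Illustrative sketch, not from the original source.
#include <vector>

std::vector<float> apply_scale(const std::vector<float>& x, float s, float b) {
  std::vector<float> y(x.size());
  paddle::lite::arm::math::scale<float>(x.data(), y.data(),
                                        static_cast<int>(x.size()), s, b);
  return y;  // y[i] == x[i] * s + b
}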
paddle/fluid/lite/arm/math/scale.h
deleted 100644 → 0    view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);

template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim,
           const float* scale_data, const float* bias_data);

template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim,
           const float* scale_data, const float* bias_data);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
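
For clarity, a scalar model of the per-slice overload declared above: each of the scale_dim slices of length inner_dim gets its own (scale, bias) pair, repeated across outer_dim. Illustrative only; scale_ref is a hypothetical name.

// Illustrative sketch, not from the original source.
void scale_ref(const float* din, float* dout, int outer_dim, int scale_dim,
               int inner_dim, const float* scale_data,
               const float* bias_data) {
  for (int n = 0; n < outer_dim; ++n) {
    for (int i = 0; i < scale_dim; ++i) {
      for (int j = 0; j < inner_dim; ++j) {
        int idx = (n * scale_dim + i) * inner_dim + j;
        dout[idx] = din[idx] * scale_data[i] + bias_data[i];
      }
    }
  }
}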
paddle/fluid/lite/arm/math/softmax.cc
deleted 100644 → 0    view file @ 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/softmax.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace
paddle
{
namespace
lite
{
namespace
arm
{
namespace
math
{
template
<
>
void
softmax_basic
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
compute_size
;
++
i
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner8_axis4
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
3
;
int
remain
=
compute_size
%
8
;
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
8
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get max axis_size == 4
const
float
*
din_ptr
=
din
+
real_index
;
const
float
*
din_ptr1
=
din_ptr
+
inner_num
;
const
float
*
din_ptr2
=
din_ptr1
+
inner_num
;
const
float
*
din_ptr3
=
din_ptr2
+
inner_num
;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr1
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr2
);
float32x4_t
vdata3
=
vld1q_f32
(
din_ptr3
);
float32x4_t
vdata01
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
vdata11
=
vld1q_f32
(
din_ptr1
+
4
);
float32x4_t
vdata21
=
vld1q_f32
(
din_ptr2
+
4
);
float32x4_t
vdata31
=
vld1q_f32
(
din_ptr3
+
4
);
float
*
dout_ptr0
=
dout
+
real_index
;
float
*
dout_ptr1
=
dout_ptr0
+
inner_num
;
float32x4_t
vmax1
=
vmaxq_f32
(
vdata0
,
vdata1
);
float32x4_t
vmax2
=
vmaxq_f32
(
vdata2
,
vdata3
);
float32x4_t
vmax11
=
vmaxq_f32
(
vdata01
,
vdata11
);
float32x4_t
vmax21
=
vmaxq_f32
(
vdata21
,
vdata31
);
float
*
dout_ptr2
=
dout_ptr1
+
inner_num
;
float
*
dout_ptr3
=
dout_ptr2
+
inner_num
;
float32x4_t
vmax
=
vmaxq_f32
(
vmax1
,
vmax2
);
float32x4_t
vmax_1
=
vmaxq_f32
(
vmax11
,
vmax21
);
// sub, exp and sum
float32x4_t
vsum0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
float32x4_t
vsum1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax
));
float32x4_t
vsum3
=
exp_ps
(
vsubq_f32
(
vdata3
,
vmax
));
float32x4_t
vsum01
=
exp_ps
(
vsubq_f32
(
vdata01
,
vmax_1
));
float32x4_t
vsum11
=
exp_ps
(
vsubq_f32
(
vdata11
,
vmax_1
));
float32x4_t
vsum21
=
exp_ps
(
vsubq_f32
(
vdata21
,
vmax_1
));
float32x4_t
vsum31
=
exp_ps
(
vsubq_f32
(
vdata31
,
vmax_1
));
float32x4_t
vsum_1
=
vaddq_f32
(
vsum0
,
vsum1
);
float32x4_t
vsum_2
=
vaddq_f32
(
vsum2
,
vsum3
);
float32x4_t
vsum_11
=
vaddq_f32
(
vsum01
,
vsum11
);
float32x4_t
vsum_21
=
vaddq_f32
(
vsum21
,
vsum31
);
float32x4_t
vsum
=
vaddq_f32
(
vsum_1
,
vsum_2
);
float32x4_t
vsum111
=
vaddq_f32
(
vsum_11
,
vsum_21
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
float32x4_t
vinf1
=
div_ps
(
vone
,
vsum111
);
vsum0
=
vmulq_f32
(
vsum0
,
vinf
);
vsum1
=
vmulq_f32
(
vsum1
,
vinf
);
vsum2
=
vmulq_f32
(
vsum2
,
vinf
);
vsum3
=
vmulq_f32
(
vsum3
,
vinf
);
vsum01
=
vmulq_f32
(
vsum01
,
vinf1
);
vsum11
=
vmulq_f32
(
vsum11
,
vinf1
);
vsum21
=
vmulq_f32
(
vsum21
,
vinf1
);
vsum31
=
vmulq_f32
(
vsum31
,
vinf1
);
vst1q_f32
(
dout_ptr0
,
vsum0
);
vst1q_f32
(
dout_ptr1
,
vsum1
);
vst1q_f32
(
dout_ptr2
,
vsum2
);
vst1q_f32
(
dout_ptr3
,
vsum3
);
vst1q_f32
(
dout_ptr0
+
4
,
vsum01
);
vst1q_f32
(
dout_ptr1
+
4
,
vsum11
);
vst1q_f32
(
dout_ptr2
+
4
,
vsum21
);
vst1q_f32
(
dout_ptr3
+
4
,
vsum31
);
}
int
i
=
cmp_cnt
*
8
;
if
(
remain
>
4
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get max axis_size == 4
const
float
*
din_ptr
=
din
+
real_index
;
const
float
*
din_ptr1
=
din_ptr
+
inner_num
;
const
float
*
din_ptr2
=
din_ptr1
+
inner_num
;
const
float
*
din_ptr3
=
din_ptr2
+
inner_num
;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr1
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr2
);
float32x4_t
vdata3
=
vld1q_f32
(
din_ptr3
);
float
*
dout_ptr0
=
dout
+
real_index
;
float
*
dout_ptr1
=
dout_ptr0
+
inner_num
;
float32x4_t
vmax1
=
vmaxq_f32
(
vdata0
,
vdata1
);
float32x4_t
vmax2
=
vmaxq_f32
(
vdata2
,
vdata3
);
float
*
dout_ptr2
=
dout_ptr1
+
inner_num
;
float
*
dout_ptr3
=
dout_ptr2
+
inner_num
;
float32x4_t
vmax
=
vmaxq_f32
(
vmax1
,
vmax2
);
// sub, exp and sum
float32x4_t
vsum0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
float32x4_t
vsum1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax
));
float32x4_t
vsum3
=
exp_ps
(
vsubq_f32
(
vdata3
,
vmax
));
float32x4_t
vsum_1
=
vaddq_f32
(
vsum0
,
vsum1
);
float32x4_t
vsum_2
=
vaddq_f32
(
vsum2
,
vsum3
);
float32x4_t
vsum
=
vaddq_f32
(
vsum_1
,
vsum_2
);
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
vsum0
=
vmulq_f32
(
vsum0
,
vinf
);
vsum1
=
vmulq_f32
(
vsum1
,
vinf
);
vsum2
=
vmulq_f32
(
vsum2
,
vinf
);
vsum3
=
vmulq_f32
(
vsum3
,
vinf
);
vst1q_f32
(
dout_ptr0
,
vsum0
);
vst1q_f32
(
dout_ptr1
,
vsum1
);
vst1q_f32
(
dout_ptr2
,
vsum2
);
vst1q_f32
(
dout_ptr3
,
vsum3
);
i
+=
4
;
}
for
(;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner4_axis4
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
2
;
int
remain
=
compute_size
%
4
;
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
4
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get max axis_size == 4
const
float
*
din_ptr
=
din
+
real_index
;
const
float
*
din_ptr1
=
din_ptr
+
inner_num
;
const
float
*
din_ptr2
=
din_ptr1
+
inner_num
;
const
float
*
din_ptr3
=
din_ptr2
+
inner_num
;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr1
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr2
);
float32x4_t
vdata3
=
vld1q_f32
(
din_ptr3
);
float
*
dout_ptr0
=
dout
+
real_index
;
float
*
dout_ptr1
=
dout_ptr0
+
inner_num
;
float32x4_t
vmax1
=
vmaxq_f32
(
vdata0
,
vdata1
);
float32x4_t
vmax2
=
vmaxq_f32
(
vdata2
,
vdata3
);
float
*
dout_ptr2
=
dout_ptr1
+
inner_num
;
float
*
dout_ptr3
=
dout_ptr2
+
inner_num
;
float32x4_t
vmax
=
vmaxq_f32
(
vmax1
,
vmax2
);
// sub, exp and sum
float32x4_t
vsum0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
float32x4_t
vsum1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax
));
float32x4_t
vsum3
=
exp_ps
(
vsubq_f32
(
vdata3
,
vmax
));
float32x4_t
vsum_1
=
vaddq_f32
(
vsum0
,
vsum1
);
float32x4_t
vsum_2
=
vaddq_f32
(
vsum2
,
vsum3
);
float32x4_t
vsum
=
vaddq_f32
(
vsum_1
,
vsum_2
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
vsum0
=
vmulq_f32
(
vsum0
,
vinf
);
vsum1
=
vmulq_f32
(
vsum1
,
vinf
);
vsum2
=
vmulq_f32
(
vsum2
,
vinf
);
vsum3
=
vmulq_f32
(
vsum3
,
vinf
);
vst1q_f32
(
dout_ptr0
,
vsum0
);
vst1q_f32
(
dout_ptr1
,
vsum1
);
vst1q_f32
(
dout_ptr2
,
vsum2
);
vst1q_f32
(
dout_ptr3
,
vsum3
);
}
int
i
=
cmp_cnt
*
8
;
for
(;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner8
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
3
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
8
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
const
float
*
din_ptr
=
din
+
real_index
;
float32x4_t
vmax
=
vld1q_f32
(
din_ptr
);
float32x4_t
vmax2
=
vld1q_f32
(
din_ptr
+
4
);
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
din_ptr
+=
inner_num
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr
+
4
);
vmax
=
vmaxq_f32
(
vmax
,
vdata
);
vmax2
=
vmaxq_f32
(
vmax2
,
vdata2
);
}
// sub, exp and sum
din_ptr
=
din
+
real_index
;
float
*
dout_ptr
=
dout
+
real_index
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata2
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
vsum
=
exp_ps
(
vsubq_f32
(
vdata
,
vmax
));
float32x4_t
vsum2
=
exp_ps
(
vsubq_f32
(
vdata2
,
vmax2
));
din_ptr
+=
inner_num
;
vst1q_f32
(
dout_ptr
,
vsum
);
vst1q_f32
(
dout_ptr
+
4
,
vsum2
);
dout_ptr
+=
inner_num
;
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
din_ptr
+
4
);
vdata0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
vdata1
=
exp_ps
(
vsubq_f32
(
vdata1
,
vmax2
));
din_ptr
+=
inner_num
;
vsum
=
vaddq_f32
(
vsum
,
vdata0
);
vsum2
=
vaddq_f32
(
vsum2
,
vdata1
);
vst1q_f32
(
dout_ptr
,
vdata0
);
vst1q_f32
(
dout_ptr
+
4
,
vdata1
);
dout_ptr
+=
inner_num
;
}
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
float32x4_t
vinf2
=
div_ps
(
vone
,
vsum2
);
dout_ptr
=
dout
+
real_index
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
float32x4_t
vdata0
=
vld1q_f32
(
dout_ptr
);
float32x4_t
vdata1
=
vld1q_f32
(
dout_ptr
+
4
);
vdata0
=
vmulq_f32
(
vdata0
,
vinf
);
vdata1
=
vmulq_f32
(
vdata1
,
vinf2
);
vst1q_f32
(
dout_ptr
,
vdata0
);
vst1q_f32
(
dout_ptr
+
4
,
vdata1
);
dout_ptr
+=
inner_num
;
}
}
for
(
int
i
=
cmp_cnt
*
8
;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template
<
>
void
softmax_inner4
<
float
>
(
const
float
*
din
,
float
*
dout
,
const
int
axis_size
,
const
int
inner_num
,
const
int
outer_num
)
{
int
compute_size
=
inner_num
*
outer_num
;
int
cmp_cnt
=
compute_size
>>
2
;
#pragma omp parallel for
for
(
int
c
=
0
;
c
<
cmp_cnt
;
++
c
)
{
int
i
=
c
*
4
;
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// float max_data = din[real_index];
const
float
*
din_ptr
=
din
+
real_index
;
float32x4_t
vmax
=
vld1q_f32
(
din_ptr
);
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
din_ptr
+=
inner_num
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
vmax
=
vmaxq_f32
(
vmax
,
vdata
);
}
// sub, exp and sum
din_ptr
=
din
+
real_index
;
float
*
dout_ptr
=
dout
+
real_index
;
float32x4_t
vdata
=
vld1q_f32
(
din_ptr
);
float32x4_t
vsum
=
exp_ps
(
vsubq_f32
(
vdata
,
vmax
));
din_ptr
+=
inner_num
;
vst1q_f32
(
dout_ptr
,
vsum
);
dout_ptr
+=
inner_num
;
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
// real_index += inner_num;
float32x4_t
vdata0
=
vld1q_f32
(
din_ptr
);
vdata0
=
exp_ps
(
vsubq_f32
(
vdata0
,
vmax
));
din_ptr
+=
inner_num
;
vsum
=
vaddq_f32
(
vsum
,
vdata0
);
vst1q_f32
(
dout_ptr
,
vdata0
);
dout_ptr
+=
inner_num
;
}
float32x4_t
vone
=
vdupq_n_f32
(
1.0
f
);
float32x4_t
vinf
=
div_ps
(
vone
,
vsum
);
dout_ptr
=
dout
+
real_index
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
float32x4_t
vdata0
=
vld1q_f32
(
dout_ptr
);
vdata0
=
vmulq_f32
(
vdata0
,
vinf
);
vst1q_f32
(
dout_ptr
,
vdata0
);
dout_ptr
+=
inner_num
;
}
}
for
(
int
i
=
cmp_cnt
*
4
;
i
<
compute_size
;
i
++
)
{
int
idx_inner
=
i
%
inner_num
;
int
idx_outer
=
(
i
/
inner_num
)
*
axis_size
;
int
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
float
max_data
=
din
[
real_index
];
// get max
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
max_data
=
din
[
real_index
]
>
max_data
?
din
[
real_index
]
:
max_data
;
}
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// sub, exp and sum
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
float
sum_data
=
dout
[
real_index
];
for
(
int
j
=
1
;
j
<
axis_size
;
++
j
)
{
real_index
+=
inner_num
;
dout
[
real_index
]
=
expf
(
din
[
real_index
]
-
max_data
);
sum_data
+=
dout
[
real_index
];
}
float
sum_inv
=
1.
f
/
sum_data
;
real_index
=
idx_outer
*
inner_num
+
idx_inner
;
// get softmax result
for
(
int
j
=
0
;
j
<
axis_size
;
++
j
)
{
dout
[
real_index
]
*=
sum_inv
;
real_index
+=
inner_num
;
}
}
}
template <>
void softmax_inner1_large_axis<float>(const float* din, float* dout,
                                      const int outer_size,
                                      const int axis_size) {
#pragma omp parallel for
  for (int i = 0; i < outer_size; ++i) {
    const float* din_ptr = din + i * axis_size;
    float* dout_ptr = dout + i * axis_size;

    const float* din_max_ptr = din_ptr;
    int nn = axis_size >> 2;

    // get max
    float32x4_t vmax = vld1q_f32(din_max_ptr);
    din_max_ptr += 4;
    int j = 1;
    for (; j < nn; ++j) {
      vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr));
      din_max_ptr += 4;
    }
    float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
    float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
    for (j = 4 * j; j < axis_size; ++j) {
      max_data = std::max(max_data, din_max_ptr[0]);
      din_max_ptr++;
    }

    // sub, exp and sum
    const float* din_sum_ptr = din_ptr;
    float* dout_sum_ptr = dout_ptr;
    vmax = vdupq_n_f32(max_data);
    float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
    float32x4_t vsum = vsub_exp;
    vst1q_f32(dout_sum_ptr, vsub_exp);
    din_sum_ptr += 4;
    dout_sum_ptr += 4;

    j = 1;
    for (; j < nn; ++j) {
      vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
      vst1q_f32(dout_sum_ptr, vsub_exp);
      vsum = vaddq_f32(vsum, vsub_exp);
      din_sum_ptr += 4;
      dout_sum_ptr += 4;
    }
    float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
    float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
    for (j = 4 * j; j < axis_size; ++j) {
      dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
      sum_data += dout_sum_ptr[0];
      din_sum_ptr++;
      dout_sum_ptr++;
    }

    float sum_inv = 1.f / sum_data;
    float* dout_res_ptr = dout_ptr;
    float32x4_t vinv = vdupq_n_f32(sum_inv);
    // get softmax result
    j = 0;
    for (; j < nn; ++j) {
      float32x4_t vout = vld1q_f32(dout_res_ptr);
      float32x4_t vres = vmulq_f32(vout, vinv);
      vst1q_f32(dout_res_ptr, vres);
      dout_res_ptr += 4;
    }
    for (j = nn * 4; j < axis_size; ++j) {
      dout_ptr[j] *= sum_inv;
    }
  }
}
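// A minimal sketch of the two horizontal NEON reductions used above, factored
// into standalone helpers for readability. The helper names are illustrative
// (they are not part of the original file); the intrinsic sequences are the
// same ones the kernel inlines.
static inline float hmax_f32x4(float32x4_t v) {
  // Fold the high half onto the low half, then take the larger of the two
  // remaining lanes: max(v0, v1, v2, v3).
  float32x2_t vh = vmax_f32(vget_high_f32(v), vget_low_f32(v));
  return std::max(vget_lane_f32(vh, 0), vget_lane_f32(vh, 1));
}

static inline float hsum_f32x4(float32x4_t v) {
  // Fold the high half onto the low half, then add the two remaining lanes:
  // v0 + v1 + v2 + v3.
  float32x2_t vh = vadd_f32(vget_high_f32(v), vget_low_f32(v));
  return vget_lane_f32(vh, 0) + vget_lane_f32(vh, 1);
}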
template <>
void softmax_inner1_small_axis<float>(const float* din, float* dout,
                                      const int outer_size,
                                      const int axis_size) {
#pragma omp parallel for
  for (int i = 0; i < outer_size; ++i) {
    const float* din_ptr = din + i * axis_size;
    float* dout_ptr = dout + i * axis_size;
    // get max
    float max_data = din_ptr[0];
    for (int j = 1; j < axis_size; ++j) {
      max_data = std::max(max_data, din_ptr[j]);
    }
    // sub, exp and sum
    float sum_data = 0.f;
    for (int j = 0; j < axis_size; ++j) {
      dout_ptr[j] = expf(din_ptr[j] - max_data);
      sum_data += dout_ptr[j];
    }
    float sum_inv = 1.f / sum_data;
    for (int j = 0; j < axis_size; ++j) {
      dout_ptr[j] *= sum_inv;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
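All of these kernels compute the numerically stable softmax along the chosen axis: out_j = expf(in_j - max) / sum_k expf(in_k - max), where max is the maximum over the axis. Subtracting the maximum before exponentiating does not change the result (the common factor cancels in the normalization) but keeps expf from overflowing for large inputs, which is why every variant above starts with a max pass.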
paddle/fluid/lite/arm/math/softmax.h
deleted (100644 → 0)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
                   const int inner_num, const int outer_num);

template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
                          const int inner_num, const int outer_num);

template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
                          const int inner_num, const int outer_num);

template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
                    const int inner_num, const int outer_num);

template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
                    const int inner_num, const int outer_num);

template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
                               const int axis_size);

template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
                               const int axis_size);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
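The header declares one generic kernel plus several shape-specialized variants. A plausible caller-side dispatch, sketched from the declarations alone (the real call site lives in the ARM softmax kernel, not in this header, and its exact thresholds may differ):

// Sketch only: pick a variant from the declarations above based on shape.
if (inner_num == 1) {
  if (axis_size >= 4) {
    softmax_inner1_large_axis(din, dout, outer_num, axis_size);
  } else {
    softmax_inner1_small_axis(din, dout, outer_num, axis_size);
  }
} else if (axis_size == 4 && inner_num % 8 == 0) {
  softmax_inner8_axis4(din, dout, axis_size, inner_num, outer_num);
} else if (axis_size == 4 && inner_num % 4 == 0) {
  softmax_inner4_axis4(din, dout, axis_size, inner_num, outer_num);
} else if (inner_num % 8 == 0) {
  softmax_inner8(din, dout, axis_size, inner_num, outer_num);
} else if (inner_num % 4 == 0) {
  softmax_inner4(din, dout, axis_size, inner_num, outer_num);
} else {
  softmax_basic(din, dout, axis_size, inner_num, outer_num);
}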
paddle/fluid/lite/arm/math/split.cc
deleted (100644 → 0)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/split.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace
paddle
{
namespace
lite
{
namespace
arm
{
namespace
math
{
template
<
>
void
split_cpy
<
float
>
(
const
float
*
din
,
float
*
dout
,
int
num
)
{
int
cnt
=
num
>>
4
;
int
remain
=
num
%
16
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
cnt
;
i
++
)
{
const
float
*
din_ptr
=
din
+
(
i
<<
4
);
float
*
dout_ptr
=
dout
+
(
i
<<
4
);
float32x4_t
din0
=
vld1q_f32
(
din_ptr
);
float32x4_t
din1
=
vld1q_f32
(
din_ptr
+
4
);
float32x4_t
din2
=
vld1q_f32
(
din_ptr
+
8
);
float32x4_t
din3
=
vld1q_f32
(
din_ptr
+
12
);
vst1q_f32
(
dout_ptr
,
din0
);
vst1q_f32
(
dout_ptr
+
4
,
din1
);
vst1q_f32
(
dout_ptr
+
8
,
din2
);
vst1q_f32
(
dout_ptr
+
12
,
din3
);
}
if
(
remain
>
0
)
{
const
float
*
din_ptr
=
din
+
(
cnt
<<
4
);
float
*
dout_ptr
=
dout
+
(
cnt
<<
4
);
for
(
int
i
=
0
;
i
<
remain
;
i
++
)
{
*
dout_ptr
=
*
din_ptr
;
dout_ptr
++
;
din_ptr
++
;
}
}
}
template
<
>
void
split
<
float
>
(
const
float
*
din
,
const
std
::
vector
<
lite
::
Tensor
*>&
dout
,
const
int
axis
,
const
std
::
vector
<
int
>&
in_strides
)
{
int
input_offset
=
0
;
for
(
auto
out
:
dout
)
{
auto
out_dim
=
out
->
dims
();
std
::
vector
<
int
>
out_strides
(
out_dim
.
size
());
out_strides
[
out_dim
.
size
()
-
1
]
=
out_dim
[
out_dim
.
size
()
-
1
];
for
(
int
i
=
out_dim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
out_strides
[
i
]
=
out_strides
[
i
+
1
]
*
out_dim
[
i
];
}
float
*
out_data
=
out
->
mutable_data
<
float
>
();
int
before
=
out_strides
[
0
]
/
out_strides
[
axis
];
int
in_after
=
in_strides
[
axis
];
int
out_after
=
out_strides
[
axis
];
for
(
int
i
=
0
;
i
<
before
;
++
i
)
{
split_cpy
(
din
+
input_offset
+
i
*
in_after
,
out_data
+
i
*
out_after
,
out_after
);
}
input_offset
+=
out_strides
[
axis
];
}
}
}
// namespace math
}
// namespace arm
}
// namespace lite
}
// namespace paddle
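A concrete example of the stride arithmetic, assuming in_strides follows the same suffix-product convention that split computes for out_strides: splitting a [2, 6, 4] input along axis 1 into two [2, 3, 4] outputs gives in_strides = {48, 24, 4} and, for each output, out_strides = {24, 12, 4}. Hence before = 24 / 12 = 2, in_after = 24, and out_after = 12, so each output receives two runs of 12 contiguous floats (one per outer row), and input_offset advances by out_strides[axis] = 12 between outputs.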
paddle/fluid/lite/arm/math/split.h
deleted (100644 → 0)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

template <typename T>
void split_cpy(const T* din, T* dout, int num);

template <typename T>
void split(const T* din, const std::vector<lite::Tensor*>& dout,
           const int axis, const std::vector<int>& in_strides);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle