Commit 16a0bd75 (unverified)
Authored on Jul 17, 2019 by Yanzhan Yang; committed via GitHub on Jul 17, 2019
Parent commit: 734f22ff

add faster depthwise implementations (#1747)

* add faster depthwise implementations
* fix style
Showing 5 changed files, with 8046 additions and 3 deletions (+8046 −3):

- src/operators/kernel/central-arm-func/conv_arm_func.cpp (+60 −0)
- src/operators/math/depthwise/faster_depthwise_conv3x3.h (+39 −0)
- src/operators/math/depthwise/faster_depthwise_conv3x3p0.cpp (+3631 −0)
- src/operators/math/depthwise/faster_depthwise_conv3x3p1.cpp (+4312 −0)
- test/net/test_benchmark.cpp (+4 −3)
src/operators/kernel/central-arm-func/conv_arm_func.cpp @ 16a0bd75

@@ -14,6 +14,7 @@ limitations under the License. */
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include <vector>
#include "operators/math/depthwise/faster_depthwise_conv3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/depthwise_conv5x5.h"
#include "operators/math/im2col.h"
@@ -211,6 +212,65 @@ void DepthwiseConv3x3(const ConvParam<CPU> &param) {
  }
}

template <>
void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param) {
  const Tensor *input = param.Input();
  const Tensor *filter = param.Filter();
  const std::vector<int> &paddings = param.Paddings();
  const std::vector<int> &strides = param.Strides();
  const int batch_size = input->dims()[0];
  Tensor *output = param.Output();
  output->mutable_data<float>();

  if (paddings.size() == 2 && paddings[0] == paddings[1] &&
      strides.size() == 2 && strides[0] == strides[1]) {
    int pad = paddings[0];
    int stride = strides[0];
    const float *din = input->data<float>();
    float *dout = output->mutable_data<float>();
    const float *weights = filter->data<float>();
    const float *bias = nullptr;
    const int num = input->dims()[0];
    const int chin = input->dims()[1];
    const int hin = input->dims()[2];
    const int win = input->dims()[3];
    const int chout = output->dims()[1];
    const int hout = output->dims()[2];
    const int wout = output->dims()[3];
    bool flag_relu = false;
    bool flag_bias = bias != nullptr;
    if (pad == 0 && hin > 2) {
      math::depthwise::conv_depthwise_3x3p0(din, dout, num, chout, hout, wout,
                                            chin, hin, win, weights, bias,
                                            stride, flag_bias, flag_relu);
    } else if (pad == 1) {
      math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout,
                                            chin, hin, win, weights, bias,
                                            stride, flag_bias, flag_relu);
    } else {
      GemmConv<float, float>(param);
    }
  } else {
    if (strides[0] == 1) {
      for (int i = 0; i < batch_size; i++) {
        Tensor in_batch = input->Slice(i, i + 1);
        Tensor out_batch = output->Slice(i, i + 1);
        math::DepthwiseConv3x3S1<float, float>(in_batch, *filter, paddings,
                                               &out_batch);
      }
    } else if (strides[0] == 2) {
      for (int i = 0; i < batch_size; i++) {
        Tensor in_batch = input->Slice(i, i + 1);
        Tensor out_batch = output->Slice(i, i + 1);
        math::DepthwiseConv3x3S2<float, float>(in_batch, *filter, paddings,
                                               &out_batch);
      }
    } else {
      GemmConv<float, float>(param);
    }
  }
}

template <typename Itype, typename Otype>
void DepthwiseConv5x5(const ConvParam<CPU> &param) {
  const Tensor *input = param.Input();
src/operators/math/depthwise/faster_depthwise_conv3x3.h (new file, mode 100644) @ 16a0bd75
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
namespace depthwise {

void conv_depthwise_3x3p0(const float *din, float *dout, int num, int ch_out,
                          int h_out, int w_out, int ch_in, int h_in, int w_in,
                          const float *weights, const float *bias, int stride,
                          bool flag_bias, bool flag_relu);

void conv_depthwise_3x3p1(const float *din, float *dout, int num, int ch_out,
                          int h_out, int w_out, int ch_in, int h_in, int w_in,
                          const float *weights, const float *bias, int stride,
                          bool flag_bias, bool flag_relu);

}  // namespace depthwise
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
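The dispatch added to conv_arm_func.cpp above is the intended entry path for these two declarations. As a quick orientation, here is a minimal sketch of that call pattern; the run_depthwise_3x3 wrapper name and its flat argument list are illustrative only and not part of this commit.

#include "operators/math/depthwise/faster_depthwise_conv3x3.h"

namespace dw = paddle_mobile::operators::math::depthwise;

// Illustrative wrapper: selects the pad-0 or pad-1 entry point the same way
// DepthwiseConv3x3<float, float> does; other pad/stride cases fall back to
// GemmConv in the real caller.
void run_depthwise_3x3(const float *din, float *dout, int num, int ch,
                       int h_in, int w_in, int h_out, int w_out,
                       const float *weights, const float *bias,
                       int pad, int stride, bool relu) {
  bool flag_bias = bias != nullptr;
  if (pad == 0 && h_in > 2) {
    dw::conv_depthwise_3x3p0(din, dout, num, ch, h_out, w_out, ch, h_in, w_in,
                             weights, bias, stride, flag_bias, relu);
  } else if (pad == 1) {
    dw::conv_depthwise_3x3p1(din, dout, num, ch, h_out, w_out, ch, h_in, w_in,
                             weights, bias, stride, flag_bias, relu);
  }
}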
src/operators/math/depthwise/faster_depthwise_conv3x3p0.cpp (new file, mode 100644) @ 16a0bd75
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#include "framework/context.h"
#include "operators/math/depthwise/faster_depthwise_conv3x3.h"
namespace paddle_mobile {
namespace operators {
namespace math {
namespace depthwise {

void conv_depthwise_3x3s1p0_bias(float *dout, const float *din,
                                 const float *weights, const float *bias,
                                 bool flag_bias, const int num, const int ch_in,
                                 const int h_in, const int w_in,
                                 const int h_out, const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s1p0_bias_s(float *dout, const float *din,
                                   const float *weights, const float *bias,
                                   bool flag_bias, const int num,
                                   const int ch_in, const int h_in,
                                   const int w_in, const int h_out,
                                   const int w_out);

void conv_depthwise_3x3s2p0_bias(float *dout, const float *din,
                                 const float *weights, const float *bias,
                                 bool flag_bias, const int num, const int ch_in,
                                 const int h_in, const int w_in,
                                 const int h_out, const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s2p0_bias_s(float *dout, const float *din,
                                   const float *weights, const float *bias,
                                   bool flag_bias, const int num,
                                   const int ch_in, const int h_in,
                                   const int w_in, const int h_out,
                                   const int w_out);

void conv_depthwise_3x3s1p0_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s1p0_bias_s_relu(float *dout, const float *din,
                                        const float *weights,
                                        const float *bias, bool flag_bias,
                                        const int num, const int ch_in,
                                        const int h_in, const int w_in,
                                        const int h_out, const int w_out);

void conv_depthwise_3x3s2p0_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s2p0_bias_s_relu(float *dout, const float *din,
                                        const float *weights,
                                        const float *bias, bool flag_bias,
                                        const int num, const int ch_in,
                                        const int h_in, const int w_in,
                                        const int h_out, const int w_out);
void conv_depthwise_3x3p0(const float *din, float *dout, int num, int ch_out,
                          int h_out, int w_out, int ch_in, int h_in, int w_in,
                          const float *weights, const float *bias, int stride,
                          bool flag_bias, bool flag_relu) {
  if (stride == 1) {
    if (flag_relu) {
      if (w_in > 5) {
        conv_depthwise_3x3s1p0_bias_relu(dout, din, weights, bias, flag_bias,
                                         num, ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s1p0_bias_s_relu(dout, din, weights, bias, flag_bias,
                                           num, ch_in, h_in, w_in, h_out,
                                           w_out);
      }
    } else {
      if (w_in > 5) {
        conv_depthwise_3x3s1p0_bias(dout, din, weights, bias, flag_bias, num,
                                    ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s1p0_bias_s(dout, din, weights, bias, flag_bias, num,
                                      ch_in, h_in, w_in, h_out, w_out);
      }
    }
  } else {
    //! stride = 2
    if (flag_relu) {
      if (w_in > 8) {
        conv_depthwise_3x3s2p0_bias_relu(dout, din, weights, bias, flag_bias,
                                         num, ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s2p0_bias_s_relu(dout, din, weights, bias, flag_bias,
                                           num, ch_in, h_in, w_in, h_out,
                                           w_out);
      }
    } else {
      if (w_in > 8) {
        conv_depthwise_3x3s2p0_bias(dout, din, weights, bias, flag_bias, num,
                                    ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s2p0_bias_s(dout, din, weights, bias, flag_bias, num,
                                      ch_in, h_in, w_in, h_out, w_out);
      }
    }
  }
}
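The NEON kernels below are easiest to validate against a plain scalar implementation. The following reference is not part of this commit; it is a sketch assuming the same NCHW layout, per-channel 3x3 filters, and pad-0 semantics as the arguments above, with rows and columns past the input treated as zero (matching the zero_ptr handling in the kernels).

#include <algorithm>

// Scalar reference for the pad-0 3x3 depthwise path (illustrative only).
void depthwise_3x3p0_ref(const float *din, float *dout, int num, int ch,
                         int h_in, int w_in, int h_out, int w_out,
                         const float *weights, const float *bias, int stride,
                         bool flag_bias, bool flag_relu) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch; ++c) {
      const float *src = din + (n * ch + c) * h_in * w_in;
      const float *w = weights + c * 9;
      float *dst = dout + (n * ch + c) * h_out * w_out;
      for (int oh = 0; oh < h_out; ++oh) {
        for (int ow = 0; ow < w_out; ++ow) {
          float sum = flag_bias ? bias[c] : 0.f;
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              int ih = oh * stride + kh;  // pad == 0: no offset
              int iw = ow * stride + kw;
              if (ih < h_in && iw < w_in) {  // out-of-range taps read as zero
                sum += src[ih * w_in + iw] * w[kh * 3 + kw];
              }
            }
          }
          dst[oh * w_out + ow] = flag_relu ? std::max(sum, 0.f) : sum;
        }
      }
    }
  }
}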
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width > 4; computes 4 output rows per iteration
 */
void conv_depthwise_3x3s1p0_bias(float *dout, const float *din,
                                 const float *weights, const float *bias,
                                 bool flag_bias, const int num, const int ch_in,
                                 const int h_in, const int w_in,
                                 const int h_out, const int w_out) {
  //! pad is done implicit
  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
  //! for 4x6 convolution window
  const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};

  float *zero_ptr = static_cast<float *>(
      framework::CPUContext::Context()->get_work_space(w_in * sizeof(float)));
  memset(zero_ptr, 0, w_in * sizeof(float));
  float *write_ptr = zero_ptr + w_in;

  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  int w_stride = 9;

  int tile_w = w_out >> 2;
  int remain = w_out % 4;

  unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
  const int remian_idx[4] = {0, 1, 2, 3};

  uint32x4_t vmask_rp1 =
      vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_rp2 =
      vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_result =
      vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));

  unsigned int vmask[8];
  vst1q_u32(vmask, vmask_rp1);
  vst1q_u32(vmask + 4, vmask_rp2);

  unsigned int rmask[4];
  vst1q_u32(rmask, vmask_result);

  float32x4_t vzero = vdupq_n_f32(0.f);

  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
#ifdef __aarch64__
    for (int c = 0; c < ch_in; c++) {
      float *dout_ptr = dout_batch + c * size_out_channel;
      const float *din_ch_ptr = din_batch + c * size_in_channel;

      float bias_val = flag_bias ? bias[c] : 0.f;
      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};

      const float *wei_ptr = weights + c * w_stride;
      float32x4_t wr0 = vld1q_f32(wei_ptr);
      float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
      float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
      // wr0 = vsetq_lane_f32(0.f, wr0, 3);
      // wr1 = vsetq_lane_f32(0.f, wr1, 3);
      // wr2 = vsetq_lane_f32(0.f, wr2, 3);

      float *doutr0 = dout_ptr;
      float *doutr1 = doutr0 + w_out;
      float *doutr2 = doutr1 + w_out;
      float *doutr3 = doutr2 + w_out;

      const float *dr0 = din_ch_ptr;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *dr4 = dr3 + w_in;
      const float *dr5 = dr4 + w_in;

      const float *din_ptr0 = dr0;
      const float *din_ptr1 = dr1;
      const float *din_ptr2 = dr2;
      const float *din_ptr3 = dr3;
      const float *din_ptr4 = dr4;
      const float *din_ptr5 = dr5;

      for (int i = 0; i < h_out; i += 4) {
        //! process top pad pad_h = 1
        din_ptr0 = dr0;
        din_ptr1 = dr1;
        din_ptr2 = dr2;
        din_ptr3 = dr3;
        din_ptr4 = dr4;
        din_ptr5 = dr5;

        doutr0 = dout_ptr;
        doutr1 = doutr0 + w_out;
        doutr2 = doutr1 + w_out;
        doutr3 = doutr2 + w_out;

        dr0 = dr4;
        dr1 = dr5;
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        dr5 = dr4 + w_in;

        //! process bottom pad
        if (i + 5 >= h_in) {
          switch (i + 5 - h_in) {
            case 5:
              din_ptr1 = zero_ptr;
            case 4:
              din_ptr2 = zero_ptr;
            case 3:
              din_ptr3 = zero_ptr;
            case 2:
              din_ptr4 = zero_ptr;
            case 1:
              din_ptr5 = zero_ptr;
            case 0:
              din_ptr5 = zero_ptr;
            default:
              break;
          }
        }
        //! process bottom remain
        if (i + 4 > h_out) {
          switch (i + 4 - h_out) {
            case 3:
              doutr1 = write_ptr;
            case 2:
              doutr2 = write_ptr;
            case 1:
              doutr3 = write_ptr;
            default:
              break;
          }
        }

        int cnt = tile_w;
        asm volatile(
            "PRFM PLDL1KEEP, [%[din_ptr0]] \n"
            "PRFM PLDL1KEEP, [%[din_ptr1]] \n"
            "PRFM PLDL1KEEP, [%[din_ptr2]] \n"
            "PRFM PLDL1KEEP, [%[din_ptr3]] \n"
            "PRFM PLDL1KEEP, [%[din_ptr4]] \n"
            "PRFM PLDL1KEEP, [%[din_ptr5]] \n"
            "movi v21.4s, #0x0 \n" /* out0 = 0 */
            "ld1 {v0.4s}, [%[din_ptr0]], #16 \n"
            "ld1 {v2.4s}, [%[din_ptr1]], #16 \n"
            "ld1 {v4.4s}, [%[din_ptr2]], #16 \n"
            "ld1 {v6.4s}, [%[din_ptr3]], #16 \n"
            "ld1 {v1.4s}, [%[din_ptr0]] \n"
            "ld1 {v3.4s}, [%[din_ptr1]] \n"
            "ld1 {v5.4s}, [%[din_ptr2]] \n"
            "ld1 {v7.4s}, [%[din_ptr3]] \n"
            "ld1 {v12.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */
            "ld1 {v13.4s}, [%[bias_val]] \n"
            "ld1 {v14.4s}, [%[bias_val]] \n"
            "ld1 {v15.4s}, [%[bias_val]] \n"
            "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
            "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
            // mid
            // "cmp %[cnt], #1 \n"
            // "blt 5f \n"
            "4: \n"
            // r0
            "fmla v12.4s, v0.4s, %[w0].s[0] \n"
            "ld1 {v8.4s}, [%[din_ptr4]], #16 \n"
            "ld1 {v10.4s}, [%[din_ptr5]], #16 \n"
            "ld1 {v0.4s}, [%[din_ptr0]], #16 \n"
            "fmla v12.4s, v16.4s, %[w0].s[1] \n"
            "ld1 {v9.4s}, [%[din_ptr4]] \n"
            "ld1 {v11.4s}, [%[din_ptr5]] \n"
            "ld1 {v1.4s}, [%[din_ptr0]] \n"
            "fmla v12.4s, v17.4s, %[w0].s[2] \n"
            "ext v16.16b, v2.16b, v3.16b, #4 \n"
            "ext v17.16b, v2.16b, v3.16b, #8 \n"
            // r1
            "fmla v13.4s, v2.4s, %[w0].s[0] \n"
            "fmla v12.4s, v2.4s, %[w1].s[0] \n"
            "ld1 {v2.4s}, [%[din_ptr1]], #16 \n"
            "fmla v13.4s, v16.4s, %[w0].s[1] \n"
            "fmla v12.4s, v16.4s, %[w1].s[1] \n"
            "ld1 {v3.4s}, [%[din_ptr1]] \n"
            "fmla v13.4s, v17.4s, %[w0].s[2] \n"
            "fmla v12.4s, v17.4s, %[w1].s[2] \n"
            "ext v16.16b, v4.16b, v5.16b, #4 \n"
            "ext v17.16b, v4.16b, v5.16b, #8 \n"
            // r2
            "fmla v14.4s, v4.4s, %[w0].s[0] \n"
            "fmla v13.4s, v4.4s, %[w1].s[0] \n"
            "fmla v12.4s, v4.4s, %[w2].s[0] \n"
            "ld1 {v4.4s}, [%[din_ptr2]], #16 \n"
            "fmla v14.4s, v16.4s, %[w0].s[1] \n"
            "fmla v13.4s, v16.4s, %[w1].s[1] \n"
            "fmla v12.4s, v16.4s, %[w2].s[1] \n"
            "ld1 {v5.4s}, [%[din_ptr2]] \n"
            "fmla v14.4s, v17.4s, %[w0].s[2] \n"
            "fmla v13.4s, v17.4s, %[w1].s[2] \n"
            "fmla v12.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v6.16b, v7.16b, #4 \n"
            "ext v17.16b, v6.16b, v7.16b, #8 \n"
            // r3
            "fmla v15.4s, v6.4s, %[w0].s[0] \n"
            "fmla v14.4s, v6.4s, %[w1].s[0] \n"
            "fmla v13.4s, v6.4s, %[w2].s[0] \n"
            "ld1 {v6.4s}, [%[din_ptr3]], #16 \n"
            "st1 {v12.4s}, [%[doutr0]], #16 \n"
            "fmla v15.4s, v16.4s, %[w0].s[1] \n"
            "fmla v14.4s, v16.4s, %[w1].s[1] \n"
            "fmla v13.4s, v16.4s, %[w2].s[1] \n"
            "ld1 {v7.4s}, [%[din_ptr3]] \n"
            "ld1 {v12.4s}, [%[bias_val]] \n"
            "fmla v15.4s, v17.4s, %[w0].s[2] \n"
            "fmla v14.4s, v17.4s, %[w1].s[2] \n"
            "fmla v13.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v8.16b, v9.16b, #4 \n"
            "ext v17.16b, v8.16b, v9.16b, #8 \n"
            // r4
            "fmla v15.4s, v8.4s, %[w1].s[0] \n"
            "fmla v14.4s, v8.4s, %[w2].s[0] \n"
            "st1 {v13.4s}, [%[doutr1]], #16 \n"
            "fmla v15.4s, v16.4s, %[w1].s[1] \n"
            "fmla v14.4s, v16.4s, %[w2].s[1] \n"
            "ld1 {v13.4s}, [%[bias_val]] \n"
            "fmla v15.4s, v17.4s, %[w1].s[2] \n"
            "fmla v14.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v10.16b, v11.16b, #4 \n"
            "ext v17.16b, v10.16b, v11.16b, #8 \n"
            // r5
            "fmla v15.4s, v10.4s, %[w2].s[0] \n"
            "st1 {v14.4s}, [%[doutr2]], #16 \n"
            "fmla v15.4s, v16.4s, %[w2].s[1] \n"
            "ld1 {v14.4s}, [%[bias_val]] \n"
            "fmla v15.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v0.16b, v1.16b, #4 \n"
            "ext v17.16b, v0.16b, v1.16b, #8 \n"
            "subs %[cnt], %[cnt], #1 \n"
            "st1 {v15.4s}, [%[doutr3]], #16 \n"
            "ld1 {v15.4s}, [%[bias_val]] \n"
            "bne 4b \n"
            // right
            "5: \n"
            "cmp %[remain], #1 \n"
            "blt 0f \n"
            "ld1 {v18.4s, v19.4s}, [%[vmask]] \n"
            "ld1 {v22.4s}, [%[doutr0]] \n"
            "ld1 {v23.4s}, [%[doutr1]] \n"
            "ld1 {v24.4s}, [%[doutr2]] \n"
            "ld1 {v25.4s}, [%[doutr3]] \n"
            "bif v0.16b, %[vzero].16b, v18.16b \n"
            "bif v1.16b, %[vzero].16b, v19.16b \n"
            "bif v2.16b, %[vzero].16b, v18.16b \n"
            "bif v3.16b, %[vzero].16b, v19.16b \n"
            "ext v16.16b, v0.16b, v1.16b, #4 \n"
            "ext v17.16b, v0.16b, v1.16b, #8 \n"
            "ld1 {v8.4s}, [%[din_ptr4]], #16 \n"
            "ld1 {v10.4s}, [%[din_ptr5]], #16 \n"
            "ld1 {v9.4s}, [%[din_ptr4]] \n"
            "ld1 {v11.4s}, [%[din_ptr5]] \n"
            // r0
            "fmla v12.4s, v0.4s, %[w0].s[0] \n"
            "bif v4.16b, %[vzero].16b, v18.16b \n"
            "bif v5.16b, %[vzero].16b, v19.16b \n"
            "bif v6.16b, %[vzero].16b, v18.16b \n"
            "bif v7.16b, %[vzero].16b, v19.16b \n"
            "fmla v12.4s, v16.4s, %[w0].s[1] \n"
            "bif v8.16b, %[vzero].16b, v18.16b \n"
            "bif v9.16b, %[vzero].16b, v19.16b \n"
            "bif v10.16b, %[vzero].16b, v18.16b \n"
            "bif v11.16b, %[vzero].16b, v19.16b \n"
            "fmla v12.4s, v17.4s, %[w0].s[2] \n"
            "ext v16.16b, v2.16b, v3.16b, #4 \n"
            "ext v17.16b, v2.16b, v3.16b, #8 \n"
            "ld1 {v18.4s}, [%[rmask]] \n"
            // r1
            "fmla v13.4s, v2.4s, %[w0].s[0] \n"
            "fmla v12.4s, v2.4s, %[w1].s[0] \n"
            "fmla v13.4s, v16.4s, %[w0].s[1] \n"
            "fmla v12.4s, v16.4s, %[w1].s[1] \n"
            "fmla v13.4s, v17.4s, %[w0].s[2] \n"
            "fmla v12.4s, v17.4s, %[w1].s[2] \n"
            "ext v16.16b, v4.16b, v5.16b, #4 \n"
            "ext v17.16b, v4.16b, v5.16b, #8 \n"
            // r2
            "fmla v14.4s, v4.4s, %[w0].s[0] \n"
            "fmla v13.4s, v4.4s, %[w1].s[0] \n"
            "fmla v12.4s, v4.4s, %[w2].s[0] \n"
            "fmla v14.4s, v16.4s, %[w0].s[1] \n"
            "fmla v13.4s, v16.4s, %[w1].s[1] \n"
            "fmla v12.4s, v16.4s, %[w2].s[1] \n"
            "fmla v14.4s, v17.4s, %[w0].s[2] \n"
            "fmla v13.4s, v17.4s, %[w1].s[2] \n"
            "fmla v12.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v6.16b, v7.16b, #4 \n"
            "ext v17.16b, v6.16b, v7.16b, #8 \n"
            // r3
            "fmla v15.4s, v6.4s, %[w0].s[0] \n"
            "fmla v14.4s, v6.4s, %[w1].s[0] \n"
            "fmla v13.4s, v6.4s, %[w2].s[0] \n"
            "bif v12.16b, v22.16b, v18.16b \n"
            "fmla v15.4s, v16.4s, %[w0].s[1] \n"
            "fmla v14.4s, v16.4s, %[w1].s[1] \n"
            "fmla v13.4s, v16.4s, %[w2].s[1] \n"
            "st1 {v12.4s}, [%[doutr0]], #16 \n"
            "fmla v15.4s, v17.4s, %[w0].s[2] \n"
            "fmla v14.4s, v17.4s, %[w1].s[2] \n"
            "fmla v13.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v8.16b, v9.16b, #4 \n"
            "ext v17.16b, v8.16b, v9.16b, #8 \n"
            // r4
            "fmla v15.4s, v8.4s, %[w1].s[0] \n"
            "fmla v14.4s, v8.4s, %[w2].s[0] \n"
            "bif v13.16b, v23.16b, v18.16b \n"
            "fmla v15.4s, v16.4s, %[w1].s[1] \n"
            "fmla v14.4s, v16.4s, %[w2].s[1] \n"
            "st1 {v13.4s}, [%[doutr1]], #16 \n"
            "fmla v15.4s, v17.4s, %[w1].s[2] \n"
            "fmla v14.4s, v17.4s, %[w2].s[2] \n"
            "ext v16.16b, v10.16b, v11.16b, #4 \n"
            "ext v17.16b, v10.16b, v11.16b, #8 \n"
            // r5
            "fmla v15.4s, v10.4s, %[w2].s[0] \n"
            "bif v14.16b, v24.16b, v18.16b \n"
            "fmla v15.4s, v16.4s, %[w2].s[1] \n"
            "st1 {v14.4s}, [%[doutr2]], #16 \n"
            "fmla v15.4s, v17.4s, %[w2].s[2] \n"
            "bif v15.16b, v25.16b, v18.16b \n"
            "st1 {v15.4s}, [%[doutr3]], #16 \n"
            // end
            "0: \n"
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
              [vzero] "w"(vzero), [remain] "r"(remain)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        dout_ptr = dout_ptr + 4 * w_out;
      }
    }
#else
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;

      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);

      float bias_val = flag_bias ? bias[i] : 0.f;
      float *dout_channel = dout_batch + i * size_out_channel;

      const float *dr0 = din_channel;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;

      const float *din0_ptr = nullptr;
      const float *din1_ptr = nullptr;
      const float *din2_ptr = nullptr;
      const float *din3_ptr = nullptr;

      float *doutr0 = nullptr;
      float *doutr1 = nullptr;

      float *ptr_zero = const_cast<float *>(zero);

      for (int i = 0; i < h_out; i += 2) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        din3_ptr = dr3;

        doutr0 = dout_channel;
        doutr1 = dout_channel + w_out;

        dr0 = dr2;
        dr1 = dr3;
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;

        //! process bottom pad
        if (i + 3 >= h_in) {
          switch (i + 3 - h_in) {
            case 3:
              din1_ptr = zero_ptr;
            case 2:
              din2_ptr = zero_ptr;
            case 1:
              din3_ptr = zero_ptr;
            case 0:
              din3_ptr = zero_ptr;
            default:
              break;
          }
        }
        //! process bottom remain
        if (i + 2 > h_out) {
          doutr1 = write_ptr;
        }

        int cnt = tile_w;
        unsigned int *rmask_ptr = rmask;
        unsigned int *vmask_ptr = vmask;
        asm volatile(
            "pld [%[din0_ptr]]                  @ preload data\n"
            "pld [%[din1_ptr]]                  @ preload data\n"
            "pld [%[din2_ptr]]                  @ preload data\n"
            "pld [%[din3_ptr]]                  @ preload data\n"
            "vld1.32 {d16-d17}, [%[din0_ptr]]!  @ load din r0\n"
            "vld1.32 {d20-d21}, [%[din1_ptr]]!  @ load din r1\n"
            "vld1.32 {d24-d25}, [%[din2_ptr]]!  @ load din r2\n"
            "vld1.32 {d28-d29}, [%[din3_ptr]]!  @ load din r3\n"
            "vld1.32 {d18}, [%[din0_ptr]]       @ load din r0\n"
            "vld1.32 {d22}, [%[din1_ptr]]       @ load din r1\n"
            "vld1.32 {d26}, [%[din2_ptr]]       @ load din r2\n"
            "vld1.32 {d30}, [%[din3_ptr]]       @ load din r3\n"
            "vdup.32 q4, %[bias_val]            @ q4 = vbias\n"
            "vdup.32 q5, %[bias_val]            @ q5 = vbias\n"
            "vext.32 q6, q8, q9, #1             @ 1234\n"
            "vext.32 q7, q8, q9, #2             @ 2345\n"
            // mid
            "1:                                 @ main loop\n"
            // r0
            "vmla.f32 q4, q8, %e[wr0][0]        @ q4 += 0123 * wr0[0]\n"
            "pld [%[din0_ptr]]                  @ preload data\n"
            "pld [%[din1_ptr]]                  @ preload data\n"
            "pld [%[din2_ptr]]                  @ preload data\n"
            "pld [%[din3_ptr]]                  @ preload data\n"
            "vmla.f32 q4, q6, %e[wr0][1]        @ q4 += 1234 * wr0[1]\n"
            "vld1.32 {d16-d17}, [%[din0_ptr]]!  @ load din r0\n"
            "vmla.f32 q4, q7, %f[wr0][0]        @ q4 += 2345 * wr0[2]\n"
            "vld1.32 {d18}, [%[din0_ptr]]       @ load din r0\n"
            "vext.32 q6, q10, q11, #1           @ 1234\n"
            "vext.32 q7, q10, q11, #2           @ 2345\n"
            // r1
            "vmla.f32 q5, q10, %e[wr0][0] \n"
            "vmla.f32 q4, q10, %e[wr1][0] \n"
            "vld1.32 {d20-d21}, [%[din1_ptr]]!  @ load din r1\n"
            "vmla.f32 q5, q6, %e[wr0][1] \n"
            "vmla.f32 q4, q6, %e[wr1][1] \n"
            "vld1.32 {d22}, [%[din1_ptr]]       @ load din r1\n"
            "vmla.f32 q5, q7, %f[wr0][0] \n"
            "vmla.f32 q4, q7, %f[wr1][0] \n"
            "vext.32 q6, q12, q13, #1           @ 1234\n"
            "vext.32 q7, q12, q13, #2           @ 2345\n"
            // r2
            "vmla.f32 q5, q12, %e[wr1][0] \n"
            "vmla.f32 q4, q12, %e[wr2][0] \n"
            "vld1.32 {d24-d25}, [%[din2_ptr]]!  @ load din r2\n"
            "vmla.f32 q5, q6, %e[wr1][1] \n"
            "vmla.f32 q4, q6, %e[wr2][1] \n"
            "vld1.32 {d26}, [%[din2_ptr]]       @ load din r2\n"
            "vmla.f32 q5, q7, %f[wr1][0] \n"
            "vmla.f32 q4, q7, %f[wr2][0] \n"
            "vext.32 q6, q14, q15, #1           @ 1234\n"
            "vext.32 q7, q14, q15, #2           @ 2345\n"
            // r3
            "vmla.f32 q5, q14, %e[wr2][0] \n"
            "vld1.32 {d28-d29}, [%[din3_ptr]]!  @ load din r3\n"
            "vst1.32 {d8-d9}, [%[dout_ptr1]]!   @ store result, add pointer\n"
            "vmla.f32 q5, q6, %e[wr2][1] \n"
            "vld1.32 {d30}, [%[din3_ptr]]       @ load din r3\n"
            "vdup.32 q4, %[bias_val]            @ q4 = vbias\n"
            "vmla.f32 q5, q7, %f[wr2][0] \n"
            "vext.32 q6, q8, q9, #1             @ 1234\n"
            "vext.32 q7, q8, q9, #2             @ 2345\n"
            "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
            "subs %[cnt], #1                    @ loop count minus 1\n"
            "vdup.32 q5, %[bias_val]            @ q5 = vbias\n"
            "bne 1b                             @ jump to main loop start\n"
            // right
            "3:                                 @ right pad entry\n"
            "cmp %[remain], #1                  @ check remaining cols\n"
            "blt 0f \n"
            "vld1.32 {d19}, [%[vmask]]!         @ load mask\n"
            "vld1.32 {d23}, [%[vmask]]! \n"
            "vld1.32 {d27}, [%[vmask]]! \n"
            "vld1.32 {d31}, [%[vmask]]! \n"
            "vbif d16, %e[vzero], d19           @ right pad\n"
            "vbif d17, %e[vzero], d23 \n"
            "vbif d18, %e[vzero], d27 \n"
            "vbif d20, %e[vzero], d19 \n"
            "vbif d21, %e[vzero], d23 \n"
            "vbif d22, %e[vzero], d27 \n"
            "vext.32 q6, q8, q9, #1             @ 1234\n"
            "vext.32 q7, q8, q9, #2             @ 2345\n"
            // r0
            "vmla.f32 q4, q8, %e[wr0][0] \n"
            "vbif d24, %e[vzero], d19 \n"
            "vbif d25, %e[vzero], d23 \n"
            "vbif d26, %e[vzero], d27 \n"
            "vmla.f32 q4, q6, %e[wr0][1] \n"
            "vbif d28, %e[vzero], d19 \n"
            "vbif d29, %e[vzero], d23 \n"
            "vbif d30, %e[vzero], d27 \n"
            "vmla.f32 q4, q7, %f[wr0][0] \n"
            "vext.32 q6, q10, q11, #1           @ 1234\n"
            "vext.32 q7, q10, q11, #2           @ 2345\n"
            // r1
            "vmla.f32 q5, q10, %e[wr0][0] \n"
            "vmla.f32 q4, q10, %e[wr1][0] \n"
            "vld1.32 {d19}, [%[rmask]]!         @ load result mask\n"
            "vld1.32 {d23}, [%[rmask]]! \n"
            "vmla.f32 q5, q6, %e[wr0][1] \n"
            "vmla.f32 q4, q6, %e[wr1][1] \n"
            "vld1.32 {d16-d17}, [%[dout_ptr1]]  @ load dout r0\n"
            "vld1.32 {d20-d21}, [%[dout_ptr2]]  @ load dout r1\n"
            "vmla.f32 q5, q7, %f[wr0][0] \n"
            "vmla.f32 q4, q7, %f[wr1][0] \n"
            "vext.32 q6, q12, q13, #1           @ 1234\n"
            "vext.32 q7, q12, q13, #2           @ 2345\n"
            // r2
            "vmla.f32 q5, q12, %e[wr1][0] \n"
            "vmla.f32 q4, q12, %e[wr2][0] \n"
            "vmla.f32 q5, q6, %e[wr1][1] \n"
            "vmla.f32 q4, q6, %e[wr2][1] \n"
            "vmla.f32 q5, q7, %f[wr1][0] \n"
            "vmla.f32 q4, q7, %f[wr2][0] \n"
            "vext.32 q6, q14, q15, #1           @ 1234\n"
            "vext.32 q7, q14, q15, #2           @ 2345\n"
            // r3
            "vmla.f32 q5, q14, %e[wr2][0] \n"
            "vbif d8, d16, d19                  @ right pad\n"
            "vbif d9, d17, d23                  @ right pad\n"
            "vmla.f32 q5, q6, %e[wr2][1] \n"
            "vst1.32 {d8-d9}, [%[dout_ptr1]]!   @ store result, add pointer\n"
            "vmla.f32 q5, q7, %f[wr2][0] \n"
            "vbif d10, d20, d19                 @ right pad\n"
            "vbif d11, d21, d23                 @ right pad\n"
            "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
            "0: \n"
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [din3_ptr] "+r"(din3_ptr),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr), [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias_val] "r"(bias_val), [vzero] "w"(vzero),
              [remain] "r"(remain)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
              "q12", "q13", "q14", "q15");
        dout_channel += 2 * w_out;
      }
      //! end of processing mid rows
    }
#endif
  }
}
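The right-edge handling in the kernel above is driven by the vmask/rmask values computed from right_pad_idx and size_pad_right. The following plain-C++ snippet is a standalone illustration of that lane selection for one stride-1, pad-0 row; it is not part of the commit, and the chosen sizes are examples only.

#include <cstdio>

int main() {
  // Example row: w_in = 9 input columns, so w_out = w_in - 2 = 7 outputs.
  int w_in = 9, w_out = 7;
  int tile_w = w_out >> 2;  // full 4-output tiles
  int remain = w_out % 4;   // leftover outputs in the final partial tile
  unsigned int size_pad_right = 6 + (tile_w << 2) - w_in;

  // Mirrors vcgeq_u32(right_pad_idx, size_pad_right): a lane is kept (mask of
  // all ones) when its index is >= size_pad_right, i.e. the load stays inside
  // the row; the remaining lanes are zero-filled by the "bif ... vzero" ops.
  const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
  for (int lane = 0; lane < 8; ++lane) {
    bool keep = right_pad_idx[lane] >= size_pad_right;
    printf("input lane %d: %s\n", lane, keep ? "keep" : "zero-fill");
  }

  // Mirrors vcgtq_s32(remain, {0,1,2,3}): which of the 4 output lanes of the
  // final tile receive a new value; the others re-store the previous output
  // contents unchanged (the "bif v12, v22, v18" pattern).
  for (int lane = 0; lane < 4; ++lane) {
    printf("output lane %d: %s\n", lane,
           lane < remain ? "write new value" : "keep previous value");
  }
  return 0;
}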
/**
 * \brief depthwise convolution kernel 3x3, stride 2
 */
// w_in > 7
void conv_depthwise_3x3s2p0_bias(float *dout, const float *din,
                                 const float *weights, const float *bias,
                                 bool flag_bias, const int num, const int ch_in,
                                 const int h_in, const int w_in,
                                 const int h_out, const int w_out) {
  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  int out_pad_idx[4] = {0, 1, 2, 3};

  int tile_w = w_out >> 2;
  int cnt_remain = w_out % 4;

  unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));

  uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
                                   vld1q_s32(right_pad_idx));  // 0 2 4 6
  uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
                                   vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
  uint32x4_t wmask =
      vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx));  // 0 1 2 3

  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;

  float *zero_ptr = static_cast<float *>(
      framework::CPUContext::Context()->get_work_space(w_in * sizeof(float)));
  memset(zero_ptr, 0, w_in * sizeof(float));
  float *write_ptr = zero_ptr + w_in;

  unsigned int dmask[12];
  vst1q_u32(dmask, vmask_rp1);
  vst1q_u32(dmask + 4, vmask_rp2);
  vst1q_u32(dmask + 8, wmask);

  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;
      float *dout_channel = dout_batch + i * size_out_channel;

      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);

      float32x4_t vzero = vdupq_n_f32(0.f);

      float32x4_t wbias;
      float bias_c = 0.f;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
        bias_c = bias[i];
      } else {
        wbias = vdupq_n_f32(0.f);
      }

      const float *dr0 = din_channel;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *dr4 = dr3 + w_in;

      const float *din0_ptr = dr0;
      const float *din1_ptr = dr1;
      const float *din2_ptr = dr2;
      const float *din3_ptr = dr3;
      const float *din4_ptr = dr4;

      float *doutr0 = dout_channel;
      float *doutr0_ptr = nullptr;
      float *doutr1_ptr = nullptr;

#ifdef __aarch64__
      for (int i = 0; i < h_out; i += 2) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        din3_ptr = dr3;
        din4_ptr = dr4;

        doutr0_ptr = doutr0;
        doutr1_ptr = doutr0 + w_out;

        dr0 = dr4;
        dr1 = dr0 + w_in;
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;

        //! process bottom pad
        if (i + 4 >= h_in) {
          switch (i + 4 - h_in) {
            case 4:
              din1_ptr = zero_ptr;
            case 3:
              din2_ptr = zero_ptr;
            case 2:
              din3_ptr = zero_ptr;
            case 1:
              din4_ptr = zero_ptr;
            case 0:
              din4_ptr = zero_ptr;
            default:
              break;
          }
        }
        //! process output pad
        if (i + 2 > h_out) {
          doutr1_ptr = write_ptr;
        }

        int cnt = tile_w;
        asm volatile(
            // top: load 2-deinterleaved columns from each of the 5 input rows
            "0: \n"
            "prfm pldl1keep, [%[inptr0]] \n"
            "prfm pldl1keep, [%[inptr1]] \n"
            "prfm pldl1keep, [%[inptr2]] \n"
            "prfm pldl1keep, [%[inptr3]] \n"
            "prfm pldl1keep, [%[inptr4]] \n"
            "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} v1={1,3,5,7}
            "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
            "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
            "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
            "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
            "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
            "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v17 = vbias
            "ld1 {v15.4s}, [%[inptr0]] \n"
            "ld1 {v18.4s}, [%[inptr1]] \n"
            "ld1 {v19.4s}, [%[inptr2]] \n"
            "ld1 {v20.4s}, [%[inptr3]] \n"
            "ld1 {v21.4s}, [%[inptr4]] \n"
            "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
            // mid
            "2: \n"
            // r0
            "fmul v11.4s, v0.4s, %[w0].s[0] \n"
            "fmul v12.4s, v1.4s, %[w0].s[1] \n"
            "fmla v16.4s, v10.4s, %[w0].s[2] \n"
            "ext v10.16b, v2.16b, v18.16b, #4 \n"
            "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n"
            // r1
            "fmla v11.4s, v2.4s, %[w1].s[0] \n"
            "fmla v12.4s, v3.4s, %[w1].s[1] \n"
            "fmla v16.4s, v10.4s, %[w1].s[2] \n"
            "ext v10.16b, v4.16b, v19.16b, #4 \n"
            "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
            // r2
            "fmul v13.4s, v4.4s, %[w0].s[0] \n"
            "fmla v11.4s, v4.4s, %[w2].s[0] \n"
            "fmul v14.4s, v5.4s, %[w0].s[1] \n"
            "fmla v12.4s, v5.4s, %[w2].s[1] \n"
            "fmla v17.4s, v10.4s, %[w0].s[2] \n"
            "fmla v16.4s, v10.4s, %[w2].s[2] \n"
            "ext v10.16b, v6.16b, v20.16b, #4 \n"
            "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
            // r3
            "fmla v13.4s, v6.4s, %[w1].s[0] \n"
            "fmla v14.4s, v7.4s, %[w1].s[1] \n"
            "fmla v17.4s, v10.4s, %[w1].s[2] \n"
            "ext v10.16b, v8.16b, v21.16b, #4 \n"
            "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
            "fadd v16.4s, v16.4s, v11.4s \n"
            "fadd v16.4s, v16.4s, v12.4s \n"
            // r4
            "fmla v13.4s, v8.4s, %[w2].s[0] \n"
            "fmla v14.4s, v9.4s, %[w2].s[1] \n"
            "fmla v17.4s, v10.4s, %[w2].s[2] \n"
            "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
            "ld1 {v15.4s}, [%[inptr0]] \n"
            "ld1 {v18.4s}, [%[inptr1]] \n"
            "st1 {v16.4s}, [%[outptr0]], #16 \n"
            "fadd v17.4s, v17.4s, v13.4s \n"
            "ld1 {v19.4s}, [%[inptr2]] \n"
            "ld1 {v20.4s}, [%[inptr3]] \n"
            "ld1 {v21.4s}, [%[inptr4]] \n"
            "fadd v17.4s, v17.4s, v14.4s \n"
            "ext v10.16b, v0.16b, v15.16b, #4 \n"
            "and v16.16b, %[vbias].16b, %[vbias].16b \n"
            "subs %[cnt], %[cnt], #1 \n"
            "st1 {v17.4s}, [%[outptr1]], #16 \n"
            "and v17.16b, %[vbias].16b, %[vbias].16b \n"
            "bne 2b \n"
            // right
            "1: \n"
            "cmp %[remain], #1 \n"
            "blt 4f \n"
            "3: \n"
            "bif v0.16b, %[vzero].16b, %[mask1].16b \n"
            "bif v1.16b, %[vzero].16b, %[mask2].16b \n"
            "bif v2.16b, %[vzero].16b, %[mask1].16b \n"
            "bif v3.16b, %[vzero].16b, %[mask2].16b \n"
            "bif v4.16b, %[vzero].16b, %[mask1].16b \n"
            "bif v5.16b, %[vzero].16b, %[mask2].16b \n"
            "ext v10.16b, v0.16b, %[vzero].16b, #4 \n"
            "bif v6.16b, %[vzero].16b, %[mask1].16b \n"
            "bif v7.16b, %[vzero].16b, %[mask2].16b \n"
            // r0
            "fmul v11.4s, v0.4s, %[w0].s[0] \n"
            "fmul v12.4s, v1.4s, %[w0].s[1] \n"
            "fmla v16.4s, v10.4s, %[w0].s[2] \n"
            "ext v10.16b, v2.16b, %[vzero].16b, #4 \n"
            "bif v8.16b, %[vzero].16b, %[mask1].16b \n"
            "bif v9.16b, %[vzero].16b, %[mask2].16b \n"
            // r1
            "fmla v11.4s, v2.4s, %[w1].s[0] \n"
            "fmla v12.4s, v3.4s, %[w1].s[1] \n"
            "fmla v16.4s, v10.4s, %[w1].s[2] \n"
            "ext v10.16b, v4.16b, %[vzero].16b, #4 \n"
            // r2
            "fmul v13.4s, v4.4s, %[w0].s[0] \n"
            "fmla v11.4s, v4.4s, %[w2].s[0] \n"
            "fmul v14.4s, v5.4s, %[w0].s[1] \n"
            "fmla v12.4s, v5.4s, %[w2].s[1] \n"
            "fmla v17.4s, v10.4s, %[w0].s[2] \n"
            "fmla v16.4s, v10.4s, %[w2].s[2] \n"
            "ext v10.16b, v6.16b, %[vzero].16b, #4 \n"
            // r3
            "fmla v13.4s, v6.4s, %[w1].s[0] \n"
            "fmla v14.4s, v7.4s, %[w1].s[1] \n"
            "fmla v17.4s, v10.4s, %[w1].s[2] \n"
            "ext v10.16b, v8.16b, %[vzero].16b, #4 \n"
            "ld1 {v0.4s}, [%[outptr0]] \n"
            "fadd v16.4s, v16.4s, v11.4s \n"
            "fadd v16.4s, v16.4s, v12.4s \n"
            "ld1 {v1.4s}, [%[outptr1]] \n"
            // r4
            "fmla v13.4s, v8.4s, %[w2].s[0] \n"
            "fmla v14.4s, v9.4s, %[w2].s[1] \n"
            "fmla v17.4s, v10.4s, %[w2].s[2] \n"
            "bif v16.16b, v0.16b, %[wmask].16b \n"
            "fadd v17.4s, v17.4s, v13.4s \n"
            "st1 {v16.4s}, [%[outptr0]], #16 \n"
            "fadd v17.4s, v17.4s, v14.4s \n"
            "bif v17.16b, v1.16b, %[wmask].16b \n"
            "st1 {v17.4s}, [%[outptr1]], #16 \n"
            "4: \n"
            : [inptr0] "+r"(din0_ptr), [inptr1] "+r"(din1_ptr),
              [inptr2] "+r"(din2_ptr), [inptr3] "+r"(din3_ptr),
              [inptr4] "+r"(din4_ptr), [outptr0] "+r"(doutr0_ptr),
              [outptr1] "+r"(doutr1_ptr), [cnt] "+r"(cnt)
            : [vzero] "w"(vzero), [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
              [mask2] "w"(vmask_rp2), [wmask] "w"(wmask), [vbias] "w"(wbias)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21");
        doutr0 = doutr0 + 2 * w_out;
      }
#else
      for (int i = 0; i < h_out; i++) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;

        doutr0_ptr = doutr0;

        dr0 = dr2;
        dr1 = dr0 + w_in;
        dr2 = dr1 + w_in;

        //! process bottom pad
        if (i + 2 > h_in) {
          switch (i + 2 - h_in) {
            case 2:
              din1_ptr = zero_ptr;
            case 1:
              din2_ptr = zero_ptr;
            default:
              break;
          }
        }

        int cnt = tile_w;
        unsigned int *mask_ptr = dmask;
        asm volatile(
            "0: \n"
            "vmov.u32 q9, #0 \n"
            "vld2.32 {d20-d23}, [%[din0_ptr]]!  @ load din r0\n"
            "vld2.32 {d24-d27}, [%[din1_ptr]]!  @ load din r1\n"
            "vld2.32 {d28-d31}, [%[din2_ptr]]!  @ load din r2\n"
            "pld [%[din0_ptr]]                  @ preload data\n"
            "pld [%[din1_ptr]]                  @ preload data\n"
            "pld [%[din2_ptr]]                  @ preload data\n"
            "vld1.32 {d16}, [%[din0_ptr]]       @ load din r0\n"
            "vdup.32 q3, %[bias]                @ q3 = vbias\n"
            // mid
            "2: \n"
            "vext.32 q6, q10, q8, #1            @ q6 = {2,4,6,8}\n"
            "vld1.32 {d16}, [%[din1_ptr]]       @ load din r1\n"
            "vmul.f32 q4, q10, %e[wr0][0] \n"
            "vmul.f32 q5, q11, %e[wr0][1] \n"
            "vmla.f32 q3, q6, %f[wr0][0] \n"
            "vext.32 q7, q12, q8, #1            @ q7 = {2,4,6,8}\n"
            "vld1.32 {d16}, [%[din2_ptr]]       @ load din r2\n"
            "vld2.32 {d20-d23}, [%[din0_ptr]]!  @ load din r0\n"
            "vmla.f32 q4, q12, %e[wr1][0] \n"
            "vmla.f32 q5, q13, %e[wr1][1] \n"
            "vmla.f32 q3, q7, %f[wr1][0] \n"
            "vext.32 q6, q14, q8, #1            @ q6 = {2,4,6,8}\n"
            "vld2.32 {d24-d27}, [%[din1_ptr]]!  @ load din r1\n"
            "vmla.f32 q4, q14, %e[wr2][0] \n"
            "vmla.f32 q5, q15, %e[wr2][1] \n"
            "vmla.f32 q3, q6, %f[wr2][0] \n"
            "vld2.32 {d28-d31}, [%[din2_ptr]]!  @ load din r2\n"
            "vadd.f32 q3, q3, q4                @ add\n"
            "vadd.f32 q3, q3, q5                @ add\n"
            "subs %[cnt], #1 \n"
            "vld1.32 {d16}, [%[din0_ptr]]       @ load din r0\n"
            "vst1.32 {d6-d7}, [%[outptr]]! \n"
            "vdup.32 q3, %[bias]                @ q3 = vbias\n"
            "bne 2b \n"
            // right
            "1: \n"
            "cmp %[remain], #1 \n"
            "blt 3f \n"
            "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
            "vbif q10, q9, q6                   @ right pad\n"
            "vbif q11, q9, q7                   @ right pad\n"
            "vbif q12, q9, q6                   @ right pad\n"
            "vbif q13, q9, q7                   @ right pad\n"
            "vbif q14, q9, q6                   @ right pad\n"
            "vbif q15, q9, q7                   @ right pad\n"
            "vext.32 q6, q10, q9, #1            @ q6 = {2,4,6,8}\n"
            "vext.32 q7, q12, q9, #1            @ q7 = {2,4,6,8}\n"
            "vmul.f32 q4, q10, %e[wr0][0] \n"
            "vmul.f32 q5, q11, %e[wr0][1] \n"
            "vmla.f32 q3, q6, %f[wr0][0] \n"
            "vext.32 q6, q14, q9, #1            @ q6 = {2,4,6,8}\n"
            "vld1.f32 {d20-d21}, [%[outptr]]    @ load output\n"
            "vmla.f32 q4, q12, %e[wr1][0] \n"
            "vmla.f32 q5, q13, %e[wr1][1] \n"
            "vmla.f32 q3, q7, %f[wr1][0] \n"
            "vld1.f32 {d22-d23}, [%[mask_ptr]]  @ load mask\n"
            "vmla.f32 q4, q14, %e[wr2][0] \n"
            "vmla.f32 q5, q15, %e[wr2][1] \n"
            "vmla.f32 q3, q6, %f[wr2][0] \n"
            "vadd.f32 q3, q3, q4                @ add\n"
            "vadd.f32 q3, q3, q5                @ add\n"
            "vbif.f32 q3, q10, q11              @ write mask\n"
            "vst1.32 {d6-d7}, [%[outptr]]! \n"
            "3: \n"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [outptr] "+r"(doutr0_ptr),
              [cnt] "+r"(cnt), [mask_ptr] "+r"(mask_ptr)
            : [remain] "r"(cnt_remain), [wr0] "w"(wr0), [wr1] "w"(wr1),
              [wr2] "w"(wr2), [bias] "r"(bias_c)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        doutr0 = doutr0 + w_out;
      }
#endif
    }
  }
}
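The stride-2 kernel above relies on ld2/vld2 to split each input row into even and odd columns in a single load. Here is a minimal intrinsics sketch of that de-interleave; it assumes an ARM NEON target and is illustrative only, not part of the commit.

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>

// Loads 8 consecutive floats and de-interleaves them the same way the
// "ld2 {v0.4s, v1.4s}" / "vld2.32" instructions do in the kernel:
//   val[0] = {c0, c2, c4, c6}  (even columns, multiplied by w[0])
//   val[1] = {c1, c3, c5, c7}  (odd columns, multiplied by w[1])
// An "ext" by one lane against the next load then yields {c2, c4, c6, c8}
// for the third filter tap.
static inline float32x4x2_t load_even_odd(const float *row) {
  return vld2q_f32(row);
}
#endif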
// 4line
void conv_depthwise_3x3s1p0_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out) {
  //! pad is done implicit
  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
  //! for 4x6 convolution window
  const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};

  float *zero_ptr = static_cast<float *>(
      framework::CPUContext::Context()->get_work_space(w_in * sizeof(float)));
  memset(zero_ptr, 0, w_in * sizeof(float));
  float *write_ptr = zero_ptr + w_in;

  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  int w_stride = 9;

  int tile_w = w_out >> 2;
  int remain = w_out % 4;

  unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
  const int remian_idx[4] = {0, 1, 2, 3};

  uint32x4_t vmask_rp1 =
      vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_rp2 =
      vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_result =
      vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));

  unsigned int vmask[8];
  vst1q_u32(vmask, vmask_rp1);
  vst1q_u32(vmask + 4, vmask_rp2);

  unsigned int rmask[4];
  vst1q_u32(rmask, vmask_result);

  float32x4_t vzero = vdupq_n_f32(0.f);

  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
#ifdef __aarch64__
    for (int c = 0; c < ch_in; c++) {
      float *dout_ptr = dout_batch + c * size_out_channel;
      const float *din_ch_ptr = din_batch + c * size_in_channel;

      float bias_val = flag_bias ? bias[c] : 0.f;
      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};

      const float *wei_ptr = weights + c * w_stride;
      float32x4_t wr0 = vld1q_f32(wei_ptr);
      float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
      float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
      // wr0 = vsetq_lane_f32(0.f, wr0, 3);
      // wr1 = vsetq_lane_f32(0.f, wr1, 3);
      // wr2 = vsetq_lane_f32(0.f, wr2, 3);

      float *doutr0 = dout_ptr;
      float *doutr1 = doutr0 + w_out;
      float *doutr2 = doutr1 + w_out;
      float *doutr3 = doutr2 + w_out;

      const float *dr0 = din_ch_ptr;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *dr4 = dr3 + w_in;
      const float *dr5 = dr4 + w_in;

      const float *din_ptr0 = dr0;
      const float *din_ptr1 = dr1;
      const float *din_ptr2 = dr2;
      const float *din_ptr3 = dr3;
      const float *din_ptr4 = dr4;
      const float *din_ptr5 = dr5;

      for (int i = 0; i < h_out; i += 4) {
        //! process top pad pad_h = 1
        din_ptr0 = dr0;
        din_ptr1 = dr1;
        din_ptr2 = dr2;
        din_ptr3 = dr3;
        din_ptr4 = dr4;
        din_ptr5 = dr5;

        doutr0 = dout_ptr;
        doutr1 = doutr0 + w_out;
        doutr2 = doutr1 + w_out;
        doutr3 = doutr2 + w_out;

        dr0 = dr4;
        dr1 = dr5;
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        dr5 = dr4 + w_in;

        //! process bottom pad
        if (i + 5 >= h_in) {
          switch (i + 5 - h_in) {
            case 5:
              din_ptr1 = zero_ptr;
            case 4:
              din_ptr2 = zero_ptr;
            case 3:
              din_ptr3 = zero_ptr;
            case 2:
              din_ptr4 = zero_ptr;
            case 1:
              din_ptr5 = zero_ptr;
            case 0:
              din_ptr5 = zero_ptr;
            default:
              break;
          }
        }
        //! process bottom remain
        if (i + 4 > h_out) {
          switch (i + 4 - h_out) {
            case 3:
              doutr1 = write_ptr;
            case 2:
              doutr2 = write_ptr;
            case 1:
              doutr3 = write_ptr;
            default:
              break;
          }
        }

        int cnt = tile_w;
asm
volatile
(
"PRFM PLDL1KEEP, [%[din_ptr0]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr1]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr2]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr3]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr4]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr5]]
\n
"
"movi v21.4s, #0x0
\n
"
/* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v17 = 2345 */
// mid
"4:
\n
"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8
\n
"
/* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8
\n
"
/* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v12.4s, v12.4s, %[vzero].4s
\n
"
/* relu */
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8
\n
"
/* v16 = 2345 */
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
// r4
"fmla v15.4s , v8.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v13.4s, v13.4s, %[vzero].4s
\n
"
/* relu */
"fmla v15.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8
\n
"
/* v16 = 2345 */
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
// r5
"fmla v15.4s , v10.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v14.4s, v14.4s, %[vzero].4s
\n
"
/* relu */
"fmla v15.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmax v15.4s, v15.4s, %[vzero].4s
\n
"
/* relu */
"subs %[cnt], %[cnt], #1
\n
"
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"bne 4b
\n
"
// right
"5:
\n
"
"cmp %[remain], #1
\n
"
"blt 0f
\n
"
"ld1 {v18.4s, v19.4s}, [%[vmask]]
\n
"
"ld1 {v22.4s}, [%[doutr0]]
\n
"
"ld1 {v23.4s}, [%[doutr1]]
\n
"
"ld1 {v24.4s}, [%[doutr2]]
\n
"
"ld1 {v25.4s}, [%[doutr3]]
\n
"
"bif v0.16b, %[vzero].16b, v18.16b
\n
"
"bif v1.16b, %[vzero].16b, v19.16b
\n
"
"bif v2.16b, %[vzero].16b, v18.16b
\n
"
"bif v3.16b, %[vzero].16b, v19.16b
\n
"
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"bif v4.16b, %[vzero].16b, v18.16b
\n
"
"bif v5.16b, %[vzero].16b, v19.16b
\n
"
"bif v6.16b, %[vzero].16b, v18.16b
\n
"
"bif v7.16b, %[vzero].16b, v19.16b
\n
"
"fmla v12.4s, v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v8.16b, %[vzero].16b, v18.16b
\n
"
"bif v9.16b, %[vzero].16b, v19.16b
\n
"
"bif v10.16b, %[vzero].16b, v18.16b
\n
"
"bif v11.16b, %[vzero].16b, v19.16b
\n
"
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8
\n
"
/* v16 = 2345 */
"ld1 {v18.4s}, [%[rmask]]
\n
"
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8
\n
"
/* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v12.4s, v12.4s, %[vzero].4s
\n
"
/* relu */
"fmla v15.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v12.16b, v22.16b, v18.16b
\n
"
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8
\n
"
/* v16 = 2345 */
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v13.4s, v13.4s, %[vzero].4s
\n
"
/* relu */
"fmla v15.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v13.16b, v23.16b, v18.16b
\n
"
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
"ext v16.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v14.4s, v14.4s, %[vzero].4s
\n
"
/* relu */
"fmla v15.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v14.16b, v24.16b, v18.16b
\n
"
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
"fmax v15.4s, v15.4s, %[vzero].4s
\n
"
/* relu */
"bif v15.16b, v25.16b, v18.16b
\n
"
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
// end
"0:
\n
"
          : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
            [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
            [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
            [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
            [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2), [doutr3] "+r"(doutr3)
          : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2), [bias_val] "r"(vbias),
            [vmask] "r"(vmask), [rmask] "r"(rmask), [vzero] "w"(vzero),
            [remain] "r"(remain)
          : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
            "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
            "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
      dout_ptr = dout_ptr + 4 * w_out;
    }
  }
#else
  for (int i = 0; i < ch_in; ++i) {
    const float *din_channel = din_batch + i * size_in_channel;
    const float *weight_ptr = weights + i * 9;
    float32x4_t wr0 = vld1q_f32(weight_ptr);
    float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
    float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
    float bias_val = flag_bias ? bias[i] : 0.f;
    float *dout_channel = dout_batch + i * size_out_channel;
    const float *dr0 = din_channel;
    const float *dr1 = dr0 + w_in;
    const float *dr2 = dr1 + w_in;
    const float *dr3 = dr2 + w_in;
    const float *din0_ptr = nullptr;
    const float *din1_ptr = nullptr;
    const float *din2_ptr = nullptr;
    const float *din3_ptr = nullptr;
    float *doutr0 = nullptr;
    float *doutr1 = nullptr;
    float *ptr_zero = const_cast<float *>(zero);
    for (int i = 0; i < h_out; i += 2) {
      //! process top pad pad_h = 1
      din0_ptr = dr0;
      din1_ptr = dr1;
      din2_ptr = dr2;
      din3_ptr = dr3;
      doutr0 = dout_channel;
      doutr1 = dout_channel + w_out;
      dr0 = dr2;
      dr1 = dr3;
      dr2 = dr1 + w_in;
      dr3 = dr2 + w_in;
      //! process bottom pad
      if (i + 3 >= h_in) {
        switch (i + 3 - h_in) {
          case 3:
            din1_ptr = zero_ptr;
          case 2:
            din2_ptr = zero_ptr;
          case 1:
            din3_ptr = zero_ptr;
          case 0:
            din3_ptr = zero_ptr;
          default:
            break;
        }
      }
      //! process bottom remain
      if (i + 2 > h_out) {
        doutr1 = write_ptr;
      }
      int cnt = tile_w;
      unsigned int *rmask_ptr = rmask;
      unsigned int *vmask_ptr = vmask;
      asm volatile(
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0
\n
"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1
\n
"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2
\n
"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3
\n
"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0
\n
"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0
\n
"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0
\n
"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
// mid
"1: @ right pad entry
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0
\n
"
"vmax.f32 q4, q4, %q[vzero] @ relu
\n
"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
"vmax.f32 q5, q5, %q[vzero] @ relu
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
"subs %[cnt], #1 @ loop count minus 1
\n
"
"vdup.32 q5, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point
\n
"
// right
"3: @ right pad entry
\n
"
"cmp %[remain], #1 @ check whether has "
"mid cols
\n
"
"blt 0f @ jump to main loop start "
"point
\n
"
"vld1.32 {d19}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d27}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d31}, [%[vmask]]! @ load din r0
\n
"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d19}, [%[rmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[rmask]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0
\n
"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vmax.f32 q4, q4, %q[vzero] @ relu
\n
"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d8, d16, d19 @ bit select, deal with right pad
\n
"
"vbif d9, d17, d23 @ bit select, deal with right pad
\n
"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmax.f32 q5, q5, %q[vzero] @ relu
\n
"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad
\n
"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
"0:
\n
"
          : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
            [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
            [din2_ptr] "+r"(din2_ptr), [din3_ptr] "+r"(din3_ptr),
            [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr), [vmask] "+r"(vmask_ptr)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
            [bias_val] "r"(bias_val), [vzero] "w"(vzero), [remain] "r"(remain)
          : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
            "q12", "q13", "q14", "q15");
      dout_channel += 2 * w_out;
    }  //! end of processing mid rows
  }
#endif
  }
}
/**
 * \brief depthwise convolution kernel 3x3, stride 2, with relu
 */
// w_in > 7
void conv_depthwise_3x3s2p0_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out) {
  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  int out_pad_idx[4] = {0, 1, 2, 3};
  int tile_w = w_out >> 2;
  int cnt_remain = w_out % 4;
  unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));
  uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
                                   vld1q_s32(right_pad_idx));  // 0 2 4 6
  uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
                                   vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
  uint32x4_t wmask =
      vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx));  // 0 1 2 3
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  float *zero_ptr = static_cast<float *>(
      framework::CPUContext::Context()->get_work_space(w_in * sizeof(float)));
  memset(zero_ptr, 0, w_in * sizeof(float));
  float *write_ptr = zero_ptr + w_in;
  unsigned int dmask[12];
  vst1q_u32(dmask, vmask_rp1);
  vst1q_u32(dmask + 4, vmask_rp2);
  vst1q_u32(dmask + 8, wmask);
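  // A short note on the masks above (explanatory, not part of the original
  // patch): each main-loop iteration of this stride-2 kernel consumes 8 input
  // columns and produces 4 outputs, so after `tile_w` full tiles
  // `size_right_remain` input columns are left. vmask_rp1/vmask_rp2 are
  // per-lane "column index < remainder" flags for the even columns {0,2,4,6}
  // and the odd columns {1,3,5,7} of the last tile; out-of-range lanes are
  // zeroed with `bif` before the multiply-accumulates. `wmask` plays the same
  // role on the output side: only the first `cnt_remain` of the final 4
  // outputs are written back. Worked example with assumed sizes: w_in = 19,
  // w_out = 9 gives tile_w = 2, cnt_remain = 1, size_right_remain = 3, so
  // vmask_rp1 = {1,1,0,0}, vmask_rp2 = {1,0,0,0}, wmask = {1,0,0,0}.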
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float32x4_t vzero = vdupq_n_f32(0.f);
      float32x4_t wbias;
      float bias_c = 0.f;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
        bias_c = bias[i];
      } else {
        wbias = vdupq_n_f32(0.f);
      }
      const float *dr0 = din_channel;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *dr4 = dr3 + w_in;
      const float *din0_ptr = dr0;
      const float *din1_ptr = dr1;
      const float *din2_ptr = dr2;
      const float *din3_ptr = dr3;
      const float *din4_ptr = dr4;
      float *doutr0 = dout_channel;
      float *doutr0_ptr = nullptr;
      float *doutr1_ptr = nullptr;
#ifdef __aarch64__
      for (int i = 0; i < h_out; i += 2) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        din3_ptr = dr3;
        din4_ptr = dr4;
        doutr0_ptr = doutr0;
        doutr1_ptr = doutr0 + w_out;
        dr0 = dr4;
        dr1 = dr0 + w_in;
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        //! process bottom pad
        if (i + 4 >= h_in) {
          switch (i + 4 - h_in) {
            case 4:
              din1_ptr = zero_ptr;
            case 3:
              din2_ptr = zero_ptr;
            case 2:
              din3_ptr = zero_ptr;
            case 1:
              din4_ptr = zero_ptr;
            case 0:
              din4_ptr = zero_ptr;
            default:
              break;
          }
        }
        //! process output pad
        if (i + 2 > h_out) {
          doutr1_ptr = write_ptr;
        }
        int cnt = tile_w;
        asm volatile(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0:
\n
"
"prfm pldl1keep, [%[inptr0]]
\n
"
"prfm pldl1keep, [%[inptr1]]
\n
"
"prfm pldl1keep, [%[inptr2]]
\n
"
"prfm pldl1keep, [%[inptr3]]
\n
"
"prfm pldl1keep, [%[inptr4]]
\n
"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"ld1 {v15.4s}, [%[inptr0]]
\n
"
"ld1 {v18.4s}, [%[inptr1]]
\n
"
"ld1 {v19.4s}, [%[inptr2]]
\n
"
"ld1 {v20.4s}, [%[inptr3]]
\n
"
"ld1 {v21.4s}, [%[inptr4]]
\n
"
"ext v10.16b, v0.16b, v15.16b, #4
\n
"
// v10 = {2,4,6,8}
// mid
"2:
\n
"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"ld1 {v15.4s}, [%[inptr0]]
\n
"
"ld1 {v18.4s}, [%[inptr1]]
\n
"
"fmax v16.4s, v16.4s, %[vzero].4s
\n
"
/* relu */
"fadd v17.4s, v17.4s, v13.4s
\n
"
"ld1 {v19.4s}, [%[inptr2]]
\n
"
"ld1 {v20.4s}, [%[inptr3]]
\n
"
"ld1 {v21.4s}, [%[inptr4]]
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"fadd v17.4s, v17.4s, v14.4s
\n
"
"ext v10.16b, v0.16b, v15.16b, #4
\n
"
// v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"fmax v17.4s, v17.4s, %[vzero].4s
\n
"
/* relu */
"subs %[cnt], %[cnt], #1
\n
"
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"bne 2b
\n
"
// right
"1:
\n
"
"cmp %[remain], #1
\n
"
"blt 4f
\n
"
"3:
\n
"
"bif v0.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]]
\n
"
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
"ld1 {v1.4s}, [%[outptr1]]
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"fmax v16.4s, v16.4s, %[vzero].4s
\n
"
/* relu */
"fadd v17.4s, v17.4s, v13.4s
\n
"
"bif v16.16b, v0.16b, %[wmask].16b
\n
"
// pipei
"fadd v17.4s, v17.4s, v14.4s
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"fmax v17.4s, v17.4s, %[vzero].4s
\n
"
/* relu */
"bif v17.16b, v1.16b, %[wmask].16b
\n
"
// pipei
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"4:
\n
"
            : [inptr0] "+r"(din0_ptr), [inptr1] "+r"(din1_ptr),
              [inptr2] "+r"(din2_ptr), [inptr3] "+r"(din3_ptr),
              [inptr4] "+r"(din4_ptr), [outptr0] "+r"(doutr0_ptr),
              [outptr1] "+r"(doutr1_ptr), [cnt] "+r"(cnt)
            : [vzero] "w"(vzero), [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
              [mask2] "w"(vmask_rp2), [wmask] "w"(wmask), [vbias] "w"(wbias)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21");
        doutr0 = doutr0 + 2 * w_out;
      }
#else
      for (int i = 0; i < h_out; i++) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        doutr0_ptr = doutr0;
        dr0 = dr2;
        dr1 = dr0 + w_in;
        dr2 = dr1 + w_in;
        //! process bottom pad
        if (i + 2 > h_in) {
          switch (i + 2 - h_in) {
            case 2:
              din1_ptr = zero_ptr;
            case 1:
              din2_ptr = zero_ptr;
            default:
              break;
          }
        }
        int cnt = tile_w;
        unsigned int *mask_ptr = dmask;
        asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"0:
\n
"
"vmov.u32 q9, #0
\n
"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1
\n
"
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1
\n
"
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0
\n
"
// q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
// mid
"2:
\n
"
"vext.32 q6, q10, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1
\n
"
// q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1
\n
"
// q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"subs %[cnt], #1
\n
"
"vmax.f32 q3, q3, q9 @ relu
\n
"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0
\n
"
// q2={8,10,12,14}
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"bne 2b
\n
"
// right
"1:
\n
"
"cmp %[remain], #1
\n
"
"blt 3f
\n
"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask
\n
"
"vbif q10, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vext.32 q6, q10, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output
\n
"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask
\n
"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q6 * w02
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vmax.f32 q3, q3, q9 @ relu
\n
"
"vbif.f32 q3, q10, q11 @ write mask
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"3:
\n
"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [outptr] "+r"(doutr0_ptr),
              [cnt] "+r"(cnt), [mask_ptr] "+r"(mask_ptr)
            : [remain] "r"(cnt_remain), [wr0] "w"(wr0), [wr1] "w"(wr1),
              [wr2] "w"(wr2), [bias] "r"(bias_c)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        doutr0 = doutr0 + w_out;
      }
#endif
    }
  }
}
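// A minimal intrinsics sketch (not taken from this patch) of the
// de-interleaving trick the stride-2 assembly above relies on: vld2q_f32
// splits 8 consecutive input columns into even lanes {0,2,4,6} and odd lanes
// {1,3,5,7}; shifting the even vector left by one lane yields {2,4,6,8}, so a
// stride-2 3-tap row reduces to three multiply-accumulates. The helper name,
// its single-row scope, and the requirement that `row[0..8]` be readable are
// assumptions made for illustration only.
static inline float32x4_t dw3x3s2_row_sketch(const float *row, const float *w3,
                                             float32x4_t acc) {
  float32x4x2_t v = vld2q_f32(row);           // val[0]={0,2,4,6} val[1]={1,3,5,7}
  float32x4_t next = vld1q_dup_f32(row + 8);  // column 8 feeds the last lane
  float32x4_t shifted = vextq_f32(v.val[0], next, 1);  // {2,4,6,8}
  acc = vmlaq_n_f32(acc, v.val[0], w3[0]);  // even columns * w[0]
  acc = vmlaq_n_f32(acc, v.val[1], w3[1]);  // odd columns  * w[1]
  acc = vmlaq_n_f32(acc, shifted, w3[2]);   // columns + 2  * w[2]
  return acc;
}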
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width <= 4
 */
void conv_depthwise_3x3s1p0_bias_s(float *dout, const float *din,
                                   const float *weights, const float *bias,
                                   bool flag_bias, const int num,
                                   const int ch_in, const int h_in,
                                   const int w_in, const int h_out,
                                   const int w_out) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const
int
right_pad_idx
[
8
]
=
{
5
,
4
,
3
,
2
,
1
,
0
,
0
,
0
};
const
float
zero_ptr
[
4
]
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
uint32x4_t
vmask_rp1
=
vcgeq_s32
(
vld1q_s32
(
right_pad_idx
),
vdupq_n_s32
(
6
-
w_in
));
uint32x4_t
vmask_rp2
=
vcgeq_s32
(
vld1q_s32
(
right_pad_idx
+
4
),
vdupq_n_s32
(
6
-
w_in
));
unsigned
int
vmask
[
8
];
vst1q_u32
(
vmask
,
vmask_rp1
);
vst1q_u32
(
vmask
+
4
,
vmask_rp2
);
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
ch_in
;
++
i
)
{
float
*
dout_channel
=
dout_batch
+
i
*
size_out_channel
;
const
float
*
din_channel
=
din_batch
+
i
*
size_in_channel
;
const
float
*
weight_ptr
=
weights
+
i
*
9
;
float32x4_t
wr0
=
vld1q_f32
(
weight_ptr
);
float32x4_t
wr1
=
vld1q_f32
(
weight_ptr
+
3
);
float32x4_t
wr2
=
vld1q_f32
(
weight_ptr
+
6
);
float32x4_t
wbias
;
if
(
flag_bias
)
{
wbias
=
vdupq_n_f32
(
bias
[
i
]);
}
else
{
wbias
=
vdupq_n_f32
(
0.
f
);
}
float
out_buf1
[
4
];
float
out_buf2
[
4
];
float
trash_buf
[
4
];
float
*
doutr0
=
dout_channel
;
float
*
doutr1
=
dout_channel
+
w_out
;
for
(
int
j
=
0
;
j
<
h_out
;
j
+=
2
)
{
const
float
*
dr0
=
din_channel
+
j
*
w_in
;
const
float
*
dr1
=
dr0
+
w_in
;
const
float
*
dr2
=
dr1
+
w_in
;
const
float
*
dr3
=
dr2
+
w_in
;
doutr0
=
dout_channel
+
j
*
w_out
;
doutr1
=
doutr0
+
w_out
;
if
(
j
+
3
>=
h_in
)
{
switch
(
j
+
3
-
h_in
)
{
case
3
:
dr1
=
zero_ptr
;
case
2
:
dr2
=
zero_ptr
;
case
1
:
dr3
=
zero_ptr
;
doutr1
=
trash_buf
;
case
0
:
dr3
=
zero_ptr
;
doutr1
=
trash_buf
;
default:
break
;
}
}
#ifdef __aarch64__
asm
volatile
(
"prfm pldl1keep, [%[din0]]
\n
"
"prfm pldl1keep, [%[din1]]
\n
"
"prfm pldl1keep, [%[din2]]
\n
"
"prfm pldl1keep, [%[din3]]
\n
"
"ld1 {v0.4s, v1.4s}, [%[din0]]
\n
"
"ld1 {v2.4s, v3.4s}, [%[din1]]
\n
"
"ld1 {v4.4s, v5.4s}, [%[din2]]
\n
"
"ld1 {v6.4s, v7.4s}, [%[din3]]
\n
"
"bif v0.16b, %[zero].16b, %[mask1].16b
\n
"
// d0_1234
"bif v1.16b, %[zero].16b, %[mask2].16b
\n
"
// d0_1234
"bif v2.16b, %[zero].16b, %[mask1].16b
\n
"
// d1_1234
"bif v3.16b, %[zero].16b, %[mask2].16b
\n
"
// d1_1234
"bif v4.16b, %[zero].16b, %[mask1].16b
\n
"
// d2_1234
"bif v5.16b, %[zero].16b, %[mask2].16b
\n
"
// d2_1234
"bif v6.16b, %[zero].16b, %[mask1].16b
\n
"
// d3_1234
"bif v7.16b, %[zero].16b, %[mask2].16b
\n
"
// d3_1234
"ext v8.16b, v0.16b, v1.16b, #4
\n
"
// d1_2345
"ext v9.16b, v0.16b, v1.16b, #8
\n
"
// d1_3450
"and v12.16b, %[vbias].16b, %[vbias].16b
\n
"
// v12 = vbias
"and v13.16b, %[vbias].16b, %[vbias].16b
\n
"
// v13 = vbias
// r0
"fmul v10.4s, v0.4s, %[wr0].s[0]
\n
"
// d0_1234 * w0[0]
"fmul v11.4s, v8.4s, %[wr0].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v12.4s, v9.4s, %[wr0].s[2]
\n
"
// d0_3456 * w0[2]
"ext v8.16b, v2.16b, v3.16b, #4
\n
"
// d1_2345
"ext v9.16b, v2.16b, v3.16b, #8
\n
"
// d1_3450
// r1
"fmul v14.4s, v2.4s, %[wr0].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v10.4s, v2.4s, %[wr1].s[0]
\n
"
// d0_1234 * w0[0]
"fmul v15.4s, v8.4s, %[wr0].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr1].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr0].s[2]
\n
"
// d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr1].s[2]
\n
"
// d0_3456 * w0[2]
"ext v8.16b, v4.16b, v5.16b, #4
\n
"
// d1_2345
"ext v9.16b, v4.16b, v5.16b, #8
\n
"
// d1_3450
// r2
"fmla v14.4s, v4.4s, %[wr1].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v10.4s, v4.4s, %[wr2].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr1].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr2].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr1].s[2]
\n
"
// d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr2].s[2]
\n
"
// d0_3456 * w0[2]
"ext v8.16b, v6.16b, v7.16b, #4
\n
"
// d1_2345
"ext v9.16b, v6.16b, v7.16b, #8
\n
"
// d1_3450
// r3
"fmla v14.4s, v6.4s, %[wr2].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr2].s[1]
\n
"
// d1_2345 * w0[1]
"fadd v12.4s, v12.4s, v10.4s
\n
"
"fmla v13.4s, v9.4s, %[wr2].s[2]
\n
"
// d0_3456 * w0[2]
"fadd v12.4s, v12.4s, v11.4s
\n
"
// out1
"fadd v13.4s, v13.4s, v14.4s
\n
"
// out2
"fadd v13.4s, v13.4s, v15.4s
\n
"
// out2
"prfm pldl1keep, [%[out1]]
\n
"
"prfm pldl1keep, [%[out2]]
\n
"
"st1 {v12.4s}, [%[out1]]
\n
"
"st1 {v13.4s}, [%[out2]]
\n
"
          : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
            [din3] "+r"(dr3)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2), [vbias] "w"(wbias),
            [mask1] "w"(vmask_rp1), [mask2] "w"(vmask_rp2), [zero] "w"(vzero),
            [out1] "r"(out_buf1), [out2] "r"(out_buf2)
          : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
            "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
#else
      unsigned int *vmask_ptr = vmask;
      float bias_val = flag_bias ? bias[i] : 0.f;
      asm volatile(
"pld [%[din0]]
\n
"
"pld [%[din1]]
\n
"
"pld [%[din2]]
\n
"
"pld [%[din3]]
\n
"
"vld1.32 {d16-d18}, [%[din0]] @ load din r0
\n
"
"vld1.32 {d20-d22}, [%[din1]] @ load din r1
\n
"
"vld1.32 {d24-d26}, [%[din2]] @ load din r2
\n
"
"vld1.32 {d28-d30}, [%[din3]] @ load din r3
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"vld1.32 {d19}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d27}, [%[vmask]]! @ load din r0
\n
"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vadd.f32 q4, q4, q10 @ q4 += q10
\n
"
"pld [%[out1]]
\n
"
"pld [%[out2]]
\n
"
"vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vadd.f32 q4, q4, q11 @ q4 += q10
\n
"
"vadd.f32 q5, q5, q8 @ q4 += q10
\n
"
"vadd.f32 q5, q5, q9 @ q4 += q10
\n
"
"vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer
\n
"
"vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer
\n
"
          : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
            [din3] "+r"(dr3), [vmask] "+r"(vmask_ptr)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2), [vzero] "w"(vzero),
            [bias_val] "r"(bias_val), [out1] "r"(out_buf1),
            [out2] "r"(out_buf2)
          : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
            "q12", "q13", "q14", "q15");
#endif  //__aarch64__
      for (int w = 0; w < w_out; ++w) {
        *doutr0++ = out_buf1[w];
        *doutr1++ = out_buf2[w];
      }
    }  // end of processing heights
  }    // end of processing channels
}      // end of processing batches
}
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/
void conv_depthwise_3x3s2p0_bias_s(float *dout, const float *din,
                                   const float *weights, const float *bias,
                                   bool flag_bias, const int num,
                                   const int ch_in, const int h_in,
                                   const int w_in, const int h_out,
                                   const int w_out) {
int
right_pad_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
int
out_pad_idx
[
4
]
=
{
0
,
1
,
2
,
3
};
float
zeros
[
8
]
=
{
0.0
f
};
uint32x4_t
vmask_rp1
=
vcgtq_s32
(
vdupq_n_s32
(
w_in
),
vld1q_s32
(
right_pad_idx
));
// 0 2 4 6
uint32x4_t
vmask_rp2
=
vcgtq_s32
(
vdupq_n_s32
(
w_in
),
vld1q_s32
(
right_pad_idx
+
4
));
// 1 3 5 7
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
unsigned
int
dmask
[
8
];
vst1q_u32
(
dmask
,
vmask_rp1
);
vst1q_u32
(
dmask
+
4
,
vmask_rp2
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
ch_in
;
++
i
)
{
const
float
*
din_channel
=
din_batch
+
i
*
size_in_channel
;
float
*
dout_channel
=
dout_batch
+
i
*
size_out_channel
;
const
float
*
weight_ptr
=
weights
+
i
*
9
;
float32x4_t
wr0
=
vld1q_f32
(
weight_ptr
);
float32x4_t
wr1
=
vld1q_f32
(
weight_ptr
+
3
);
float32x4_t
wr2
=
vld1q_f32
(
weight_ptr
+
6
);
float
bias_c
=
0.
f
;
if
(
flag_bias
)
{
bias_c
=
bias
[
i
];
}
float32x4_t
vbias
=
vdupq_n_f32
(
bias_c
);
float
out_buf
[
4
];
const
float
*
dr0
=
din_channel
;
const
float
*
dr1
=
dr0
+
w_in
;
const
float
*
dr2
=
dr1
+
w_in
;
for
(
int
j
=
0
;
j
<
h_out
;
++
j
)
{
const
float
*
din0_ptr
=
dr0
;
const
float
*
din1_ptr
=
dr1
;
const
float
*
din2_ptr
=
dr2
;
dr0
=
dr2
;
dr1
=
dr0
+
w_in
;
dr2
=
dr1
+
w_in
;
unsigned
int
*
mask_ptr
=
dmask
;
#ifdef __aarch64__
asm
volatile
(
// Load up 12 elements (3 vectors) from each of 8 sources.
"movi v9.4s, #0
\n
"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32
\n
"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32
\n
"
// v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32
\n
"
// v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32
\n
"
// v14={0,2,4,6}
// v15={1,3,5,7}
"and v4.16b, %[bias].16b, %[bias].16b
\n
"
// v10 = vbias
"bif v10.16b, v9.16b, v6.16b
\n
"
"bif v11.16b, v9.16b, v7.16b
\n
"
"bif v12.16b, v9.16b, v6.16b
\n
"
"bif v13.16b, v9.16b, v7.16b
\n
"
"bif v14.16b, v9.16b, v6.16b
\n
"
"bif v15.16b, v9.16b, v7.16b
\n
"
"ext v6.16b, v10.16b, v9.16b, #4
\n
"
// v6 =
// {2,4,6,8}
"ext v7.16b, v12.16b, v9.16b, #4
\n
"
// v6 =
// {2,4,6,8}
"ext v8.16b, v14.16b, v9.16b, #4
\n
"
// v6 =
// {2,4,6,8}
"fmla v4.4s, v10.4s, %[wr0].s[0]
\n
"
// 0246 * w00
"fmul v5.4s, v11.4s, %[wr0].s[1]
\n
"
// 1357 * w01
"fmul v16.4s, v6.4s, %[wr0].s[2]
\n
"
// 2468 * w02
"fmla v4.4s, v12.4s, %[wr1].s[0]
\n
"
// v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[1]
\n
"
// v13 * w12
"fmla v16.4s, v7.4s, %[wr1].s[2]
\n
"
// v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[0]
\n
"
// v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[1]
\n
"
// v15 * w21
"fmla v16.4s, v8.4s, %[wr2].s[2]
\n
"
// v8 * w22
"fadd v4.4s, v4.4s, v5.4s
\n
"
"fadd v4.4s, v4.4s, v16.4s
\n
"
// "fadd v4.4s, v4.4s, %[bias].4s \n"
"st1 {v4.4s}, [%[out]]
\n
"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [mask_ptr] "+r"(mask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias] "w"(vbias), [out] "r"(out_buf)
            : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
              "v12", "v13", "v14", "v15", "v16");
#else
asm
volatile
(
// Load up 12 elements (3 vectors) from each of 8 sources.
"vmov.u32 q9, #0
\n
"
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2
\n
"
// q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vext.32 q6, q10, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,0}
"vext.32 q7, q12, q9, #1 @ shift left 1
\n
"
// q7 = {2,4,6,0}
"vext.32 q8, q14, q9, #1 @ shift left 1
\n
"
// q8 = {2,4,6,0}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// {0,2,4,6}
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// {1,3,5,7}
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// {2,4,6,0}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q12 * w11
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q13 * w12
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q7 * w10
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q14 * w20
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q15 * w21
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q8 * w22
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vst1.32 {d6-d7}, [%[out]]
\n
"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias] "r"(bias_c), [out] "r"(out_buf), [mask_ptr] "r"(dmask)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
#endif  //__aarch64__
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
      }
    }
  }
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias
 * and relu, width <= 4
 */
void conv_depthwise_3x3s1p0_bias_s_relu(float *dout, const float *din,
                                        const float *weights,
                                        const float *bias, bool flag_bias,
                                        const int num, const int ch_in,
                                        const int h_in, const int w_in,
                                        const int h_out, const int w_out) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const
int
right_pad_idx
[
8
]
=
{
5
,
4
,
3
,
2
,
1
,
0
,
0
,
0
};
const
float
zero_ptr
[
4
]
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
uint32x4_t
vmask_rp1
=
vcgeq_s32
(
vld1q_s32
(
right_pad_idx
),
vdupq_n_s32
(
6
-
w_in
));
uint32x4_t
vmask_rp2
=
vcgeq_s32
(
vld1q_s32
(
right_pad_idx
+
4
),
vdupq_n_s32
(
6
-
w_in
));
unsigned
int
vmask
[
8
];
vst1q_u32
(
vmask
,
vmask_rp1
);
vst1q_u32
(
vmask
+
4
,
vmask_rp2
);
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
ch_in
;
++
i
)
{
float
*
dout_channel
=
dout_batch
+
i
*
size_out_channel
;
const
float
*
din_channel
=
din_batch
+
i
*
size_in_channel
;
const
float
*
weight_ptr
=
weights
+
i
*
9
;
float32x4_t
wr0
=
vld1q_f32
(
weight_ptr
);
float32x4_t
wr1
=
vld1q_f32
(
weight_ptr
+
3
);
float32x4_t
wr2
=
vld1q_f32
(
weight_ptr
+
6
);
float32x4_t
wbias
;
if
(
flag_bias
)
{
wbias
=
vdupq_n_f32
(
bias
[
i
]);
}
else
{
wbias
=
vdupq_n_f32
(
0.
f
);
}
float
out_buf1
[
4
];
float
out_buf2
[
4
];
float
trash_buf
[
4
];
float
*
doutr0
=
dout_channel
;
float
*
doutr1
=
dout_channel
+
w_out
;
for
(
int
j
=
0
;
j
<
h_out
;
j
+=
2
)
{
const
float
*
dr0
=
din_channel
+
j
*
w_in
;
const
float
*
dr1
=
dr0
+
w_in
;
const
float
*
dr2
=
dr1
+
w_in
;
const
float
*
dr3
=
dr2
+
w_in
;
doutr0
=
dout_channel
+
j
*
w_out
;
doutr1
=
doutr0
+
w_out
;
if
(
j
+
3
>=
h_in
)
{
switch
(
j
+
3
-
h_in
)
{
case
3
:
dr1
=
zero_ptr
;
case
2
:
dr2
=
zero_ptr
;
case
1
:
dr3
=
zero_ptr
;
doutr1
=
trash_buf
;
case
0
:
dr3
=
zero_ptr
;
doutr1
=
trash_buf
;
default:
break
;
}
}
#ifdef __aarch64__
asm
volatile
(
"prfm pldl1keep, [%[din0]]
\n
"
"prfm pldl1keep, [%[din1]]
\n
"
"prfm pldl1keep, [%[din2]]
\n
"
"prfm pldl1keep, [%[din3]]
\n
"
"ld1 {v0.4s, v1.4s}, [%[din0]]
\n
"
"ld1 {v2.4s, v3.4s}, [%[din1]]
\n
"
"ld1 {v4.4s, v5.4s}, [%[din2]]
\n
"
"ld1 {v6.4s, v7.4s}, [%[din3]]
\n
"
"bif v0.16b, %[zero].16b, %[mask1].16b
\n
"
// d0_1234
"bif v1.16b, %[zero].16b, %[mask2].16b
\n
"
// d0_1234
"bif v2.16b, %[zero].16b, %[mask1].16b
\n
"
// d1_1234
"bif v3.16b, %[zero].16b, %[mask2].16b
\n
"
// d1_1234
"bif v4.16b, %[zero].16b, %[mask1].16b
\n
"
// d2_1234
"bif v5.16b, %[zero].16b, %[mask2].16b
\n
"
// d2_1234
"bif v6.16b, %[zero].16b, %[mask1].16b
\n
"
// d3_1234
"bif v7.16b, %[zero].16b, %[mask2].16b
\n
"
// d3_1234
"ext v8.16b, v0.16b, v1.16b, #4
\n
"
// d1_2345
"ext v9.16b, v0.16b, v1.16b, #8
\n
"
// d1_3450
"and v12.16b, %[vbias].16b, %[vbias].16b
\n
"
// v12 = vbias
"and v13.16b, %[vbias].16b, %[vbias].16b
\n
"
// v13 = vbias
// r0
"fmul v10.4s, v0.4s, %[wr0].s[0]
\n
"
// d0_1234 * w0[0]
"fmul v11.4s, v8.4s, %[wr0].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v12.4s, v9.4s, %[wr0].s[2]
\n
"
// d0_3456 * w0[2]
"ext v8.16b, v2.16b, v3.16b, #4
\n
"
// d1_2345
"ext v9.16b, v2.16b, v3.16b, #8
\n
"
// d1_3450
// r1
"fmul v14.4s, v2.4s, %[wr0].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v10.4s, v2.4s, %[wr1].s[0]
\n
"
// d0_1234 * w0[0]
"fmul v15.4s, v8.4s, %[wr0].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr1].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr0].s[2]
\n
"
// d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr1].s[2]
\n
"
// d0_3456 * w0[2]
"ext v8.16b, v4.16b, v5.16b, #4
\n
"
// d1_2345
"ext v9.16b, v4.16b, v5.16b, #8
\n
"
// d1_3450
// r2
"fmla v14.4s, v4.4s, %[wr1].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v10.4s, v4.4s, %[wr2].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr1].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr2].s[1]
\n
"
// d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr1].s[2]
\n
"
// d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr2].s[2]
\n
"
// d0_3456 * w0[2]
"ext v8.16b, v6.16b, v7.16b, #4
\n
"
// d1_2345
"ext v9.16b, v6.16b, v7.16b, #8
\n
"
// d1_3450
// r3
"fmla v14.4s, v6.4s, %[wr2].s[0]
\n
"
// d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr2].s[1]
\n
"
// d1_2345 * w0[1]
"fadd v12.4s, v12.4s, v10.4s
\n
"
"fmla v13.4s, v9.4s, %[wr2].s[2]
\n
"
// d0_3456 * w0[2]
"fadd v12.4s, v12.4s, v11.4s
\n
"
// out1
"fadd v13.4s, v13.4s, v14.4s
\n
"
// out2
"fadd v13.4s, v13.4s, v15.4s
\n
"
// out2
"prfm pldl1keep, [%[out1]]
\n
"
"prfm pldl1keep, [%[out2]]
\n
"
"fmax v12.4s, v12.4s, %[zero].4s
\n
"
"fmax v13.4s, v13.4s, %[zero].4s
\n
"
"st1 {v12.4s}, [%[out1]]
\n
"
"st1 {v13.4s}, [%[out2]]
\n
"
          : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
            [din3] "+r"(dr3)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2), [vbias] "w"(wbias),
            [mask1] "w"(vmask_rp1), [mask2] "w"(vmask_rp2), [zero] "w"(vzero),
            [out1] "r"(out_buf1), [out2] "r"(out_buf2)
          : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
            "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
#else
      unsigned int *vmask_ptr = vmask;
      float bias_val = flag_bias ? bias[i] : 0.f;
      asm volatile(
"pld [%[din0]]
\n
"
"pld [%[din1]]
\n
"
"pld [%[din2]]
\n
"
"pld [%[din3]]
\n
"
"vld1.32 {d16-d18}, [%[din0]] @ load din r0
\n
"
"vld1.32 {d20-d22}, [%[din1]] @ load din r1
\n
"
"vld1.32 {d24-d26}, [%[din2]] @ load din r2
\n
"
"vld1.32 {d28-d30}, [%[din3]] @ load din r3
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"vld1.32 {d19}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d27}, [%[vmask]]! @ load din r0
\n
"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vadd.f32 q4, q4, q10 @ q4 += q10
\n
"
"pld [%[out1]]
\n
"
"pld [%[out2]]
\n
"
"vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vadd.f32 q4, q4, q11 @ q4 += q10
\n
"
"vadd.f32 q5, q5, q8 @ q4 += q10
\n
"
"vadd.f32 q5, q5, q9 @ q4 += q10
\n
"
"vmax.f32 q4, q4, %q[vzero] @ relu
\n
"
"vmax.f32 q5, q5, %q[vzero] @ relu
\n
"
"vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer
\n
"
"vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer
\n
"
          : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
            [din3] "+r"(dr3), [vmask] "+r"(vmask_ptr)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2), [vzero] "w"(vzero),
            [bias_val] "r"(bias_val), [out1] "r"(out_buf1),
            [out2] "r"(out_buf2)
          : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
            "q12", "q13", "q14", "q15");
#endif  //__aarch64__
      for (int w = 0; w < w_out; ++w) {
        *doutr0++ = out_buf1[w];
        *doutr1++ = out_buf2[w];
      }
      // doutr0 = doutr1;
      // doutr1 += w_out;
    }  // end of processing heights
  }    // end of processing channels
}      // end of processing batches
}
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 7
*/
void conv_depthwise_3x3s2p0_bias_s_relu(float *dout, const float *din,
                                        const float *weights,
                                        const float *bias, bool flag_bias,
                                        const int num, const int ch_in,
                                        const int h_in, const int w_in,
                                        const int h_out, const int w_out) {
int
right_pad_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
int
out_pad_idx
[
4
]
=
{
0
,
1
,
2
,
3
};
float
zeros
[
8
]
=
{
0.0
f
};
uint32x4_t
vmask_rp1
=
vcgtq_s32
(
vdupq_n_s32
(
w_in
),
vld1q_s32
(
right_pad_idx
));
// 0 2 4 6
uint32x4_t
vmask_rp2
=
vcgtq_s32
(
vdupq_n_s32
(
w_in
),
vld1q_s32
(
right_pad_idx
+
4
));
// 1 3 5 7
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
unsigned
int
dmask
[
8
];
vst1q_u32
(
dmask
,
vmask_rp1
);
vst1q_u32
(
dmask
+
4
,
vmask_rp2
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
ch_in
;
++
i
)
{
const
float
*
din_channel
=
din_batch
+
i
*
size_in_channel
;
float
*
dout_channel
=
dout_batch
+
i
*
size_out_channel
;
const
float
*
weight_ptr
=
weights
+
i
*
9
;
float32x4_t
wr0
=
vld1q_f32
(
weight_ptr
);
float32x4_t
wr1
=
vld1q_f32
(
weight_ptr
+
3
);
float32x4_t
wr2
=
vld1q_f32
(
weight_ptr
+
6
);
float
bias_c
=
0.
f
;
if
(
flag_bias
)
{
bias_c
=
bias
[
i
];
}
float32x4_t
vbias
=
vdupq_n_f32
(
bias_c
);
float
out_buf
[
4
];
const
float
*
dr0
=
din_channel
;
const
float
*
dr1
=
dr0
+
w_in
;
const
float
*
dr2
=
dr1
+
w_in
;
for
(
int
j
=
0
;
j
<
h_out
;
++
j
)
{
const
float
*
din0_ptr
=
dr0
;
const
float
*
din1_ptr
=
dr1
;
const
float
*
din2_ptr
=
dr2
;
dr0
=
dr2
;
dr1
=
dr0
+
w_in
;
dr2
=
dr1
+
w_in
;
unsigned
int
*
mask_ptr
=
dmask
;
#ifdef __aarch64__
asm
volatile
(
// Load up 12 elements (3 vectors) from each of 8 sources.
"movi v9.4s, #0
\n
"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]]
\n
"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32
\n
"
// v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32
\n
"
// v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32
\n
"
// v14={0,2,4,6}
// v15={1,3,5,7}
"and v4.16b, %[bias].16b, %[bias].16b
\n
"
// v10 = vbias
"bif v10.16b, v9.16b, v6.16b
\n
"
"bif v11.16b, v9.16b, v7.16b
\n
"
"bif v12.16b, v9.16b, v6.16b
\n
"
"bif v13.16b, v9.16b, v7.16b
\n
"
"bif v14.16b, v9.16b, v6.16b
\n
"
"bif v15.16b, v9.16b, v7.16b
\n
"
"ext v6.16b, v10.16b, v9.16b, #4
\n
"
// v6 =
// {2,4,6,8}
"ext v7.16b, v12.16b, v9.16b, #4
\n
"
// v6 =
// {2,4,6,8}
"ext v8.16b, v14.16b, v9.16b, #4
\n
"
// v6 =
// {2,4,6,8}
"fmla v4.4s, v10.4s, %[wr0].s[0]
\n
"
// 0246 * w00
"fmul v5.4s, v11.4s, %[wr0].s[1]
\n
"
// 1357 * w01
"fmul v16.4s, v6.4s, %[wr0].s[2]
\n
"
// 2468 * w02
"fmla v4.4s, v12.4s, %[wr1].s[0]
\n
"
// v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[1]
\n
"
// v13 * w12
"fmla v16.4s, v7.4s, %[wr1].s[2]
\n
"
// v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[0]
\n
"
// v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[1]
\n
"
// v15 * w21
"fmla v16.4s, v8.4s, %[wr2].s[2]
\n
"
// v8 * w22
"fadd v4.4s, v4.4s, v5.4s
\n
"
"fadd v4.4s, v4.4s, v16.4s
\n
"
"fmax v4.4s, v4.4s, v9.4s
\n
"
// "fadd v4.4s, v4.4s, %[bias].4s \n"
"st1 {v4.4s}, [%[out]]
\n
"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias] "w"(vbias), [out] "r"(out_buf), [mask_ptr] "r"(mask_ptr)
            : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
              "v12", "v13", "v14", "v15", "v16");
#else
asm
volatile
(
// Load up 12 elements (3 vectors) from each of 8 sources.
"vmov.u32 q9, #0
\n
"
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2
\n
"
// q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vext.32 q6, q10, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,0}
"vext.32 q7, q12, q9, #1 @ shift left 1
\n
"
// q7 = {2,4,6,0}
"vext.32 q8, q14, q9, #1 @ shift left 1
\n
"
// q8 = {2,4,6,0}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// {0,2,4,6}
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// {1,3,5,7}
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// {2,4,6,0}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q12 * w11
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q13 * w12
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q7 * w10
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q14 * w20
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q15 * w21
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q8 * w22
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vmax.f32 q3, q3, q9 @ relu
\n
"
"vst1.32 {d6-d7}, [%[out]]
\n
"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias] "r"(bias_c), [out] "r"(out_buf), [mask_ptr] "r"(mask_ptr)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
#endif  //__aarch64__
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
      }
    }
  }
}
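// The NEON kernels above are hard to validate by eye. Below is a minimal
// scalar reference sketch, not part of this patch's API: the helper name
// `depthwise_conv3x3_ref` and its NCHW-layout assumption are illustrative
// only. It computes the same per-channel 3x3 convolution with padding `pad`,
// stride `stride`, optional bias and optional relu, and can be diffed against
// the fast kernels on small inputs.
static void depthwise_conv3x3_ref(const float *din, float *dout, int num,
                                  int ch, int h_in, int w_in, int h_out,
                                  int w_out, const float *weights,
                                  const float *bias, int pad, int stride,
                                  bool flag_bias, bool flag_relu) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch; ++c) {
      const float *src = din + (n * ch + c) * h_in * w_in;
      const float *w = weights + c * 9;  // one 3x3 filter per channel
      float *dst = dout + (n * ch + c) * h_out * w_out;
      for (int oh = 0; oh < h_out; ++oh) {
        for (int ow = 0; ow < w_out; ++ow) {
          float sum = flag_bias ? bias[c] : 0.f;
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              int ih = oh * stride - pad + kh;
              int iw = ow * stride - pad + kw;
              if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
                sum += src[ih * w_in + iw] * w[kh * 3 + kw];
              }
            }
          }
          dst[oh * w_out + ow] = flag_relu ? (sum > 0.f ? sum : 0.f) : sum;
        }
      }
    }
  }
}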
}  // namespace depthwise
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
#endif
src/operators/math/depthwise/faster_depthwise_conv3x3p1.cpp
0 → 100644
View file @ 16a0bd75
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#include "framework/context.h"
#include "operators/math/depthwise/faster_depthwise_conv3x3.h"
namespace paddle_mobile {
namespace operators {
namespace math {
namespace depthwise {
void conv_depthwise_3x3s1p1_bias(float *dout, const float *din,
                                 const float *weights, const float *bias,
                                 bool flag_bias, const int num,
                                 const int ch_in, const int h_in,
                                 const int w_in, const int h_out,
                                 const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s1p1_bias_s(float *dout, const float *din,
                                   const float *weights, const float *bias,
                                   bool flag_bias, const int num,
                                   const int ch_in, const int h_in,
                                   const int w_in, const int h_out,
                                   const int w_out);

void conv_depthwise_3x3s2p1_bias(float *dout, const float *din,
                                 const float *weights, const float *bias,
                                 bool flag_bias, const int num,
                                 const int ch_in, const int h_in,
                                 const int w_in, const int h_out,
                                 const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s2p1_bias_s(float *dout, const float *din,
                                   const float *weights, const float *bias,
                                   bool flag_bias, const int num,
                                   const int ch_in, const int h_in,
                                   const int w_in, const int h_out,
                                   const int w_out);

void conv_depthwise_3x3s1p1_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s1p1_bias_s_relu(float *dout, const float *din,
                                        const float *weights,
                                        const float *bias, bool flag_bias,
                                        const int num, const int ch_in,
                                        const int h_in, const int w_in,
                                        const int h_out, const int w_out);

void conv_depthwise_3x3s2p1_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out);

//! for input width <= 4
void conv_depthwise_3x3s2p1_bias_s_relu(float *dout, const float *din,
                                        const float *weights,
                                        const float *bias, bool flag_bias,
                                        const int num, const int ch_in,
                                        const int h_in, const int w_in,
                                        const int h_out, const int w_out);
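// All eight kernels above share one parameter order, so a caller could also
// select them through a small table instead of nested branches; a hedged
// sketch only (the typedef and table below are illustrative, not part of this
// file). The dispatcher that follows keeps the explicit branches because the
// small-width cutoff differs by stride (w_in <= 4 for stride 1, <= 7 for
// stride 2).
typedef void (*dw3x3_kernel_fn)(float *dout, const float *din,
                                const float *weights, const float *bias,
                                bool flag_bias, const int num, const int ch_in,
                                const int h_in, const int w_in,
                                const int h_out, const int w_out);
// index order: [stride == 2][small_width][relu]
static const dw3x3_kernel_fn g_dw3x3p1_table[2][2][2] = {
    {{conv_depthwise_3x3s1p1_bias, conv_depthwise_3x3s1p1_bias_relu},
     {conv_depthwise_3x3s1p1_bias_s, conv_depthwise_3x3s1p1_bias_s_relu}},
    {{conv_depthwise_3x3s2p1_bias, conv_depthwise_3x3s2p1_bias_relu},
     {conv_depthwise_3x3s2p1_bias_s, conv_depthwise_3x3s2p1_bias_s_relu}}};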
void conv_depthwise_3x3p1(const float *din, float *dout, int num, int ch_out,
                          int h_out, int w_out, int ch_in, int h_in, int w_in,
                          const float *weights, const float *bias, int stride,
                          bool flag_bias, bool flag_relu) {
  if (stride == 1) {
    if (flag_relu) {
      if (w_in > 4) {
        conv_depthwise_3x3s1p1_bias_relu(dout, din, weights, bias, flag_bias,
                                         num, ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s1p1_bias_s_relu(dout, din, weights, bias, flag_bias,
                                           num, ch_in, h_in, w_in, h_out,
                                           w_out);
      }
    } else {
      if (w_in > 4) {
        conv_depthwise_3x3s1p1_bias(dout, din, weights, bias, flag_bias, num,
                                    ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s1p1_bias_s(dout, din, weights, bias, flag_bias, num,
                                      ch_in, h_in, w_in, h_out, w_out);
      }
    }
  } else {
    //! stride = 2
    if (flag_relu) {
      if (w_in > 7) {
        conv_depthwise_3x3s2p1_bias_relu(dout, din, weights, bias, flag_bias,
                                         num, ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s2p1_bias_s_relu(dout, din, weights, bias, flag_bias,
                                           num, ch_in, h_in, w_in, h_out,
                                           w_out);
      }
    } else {
      if (w_in > 7) {
        conv_depthwise_3x3s2p1_bias(dout, din, weights, bias, flag_bias, num,
                                    ch_in, h_in, w_in, h_out, w_out);
      } else {
        conv_depthwise_3x3s2p1_bias_s(dout, din, weights, bias, flag_bias, num,
                                      ch_in, h_in, w_in, h_out, w_out);
      }
    }
  }
}
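// For reference: every kernel dispatched above computes, per channel, a 3x3
// convolution with an implicit zero pad of 1, an optional per-channel bias
// and an optional fused ReLU. A minimal scalar sketch of that semantics (the
// function below is illustrative only, not part of the optimized path):
static inline void conv_depthwise_3x3p1_reference(
    const float *din, float *dout, int num, int ch, int h_in, int w_in,
    int h_out, int w_out, const float *weights, const float *bias, int stride,
    bool flag_bias, bool flag_relu) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch; ++c) {
      const float *src = din + (n * ch + c) * h_in * w_in;
      const float *w = weights + c * 9;  // one 3x3 filter per channel
      float *dst = dout + (n * ch + c) * h_out * w_out;
      for (int oh = 0; oh < h_out; ++oh) {
        for (int ow = 0; ow < w_out; ++ow) {
          float sum = flag_bias ? bias[c] : 0.f;
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              int ih = oh * stride - 1 + kh;  // pad = 1
              int iw = ow * stride - 1 + kw;
              if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
                sum += src[ih * w_in + iw] * w[kh * 3 + kw];
              }
            }
          }
          dst[oh * w_out + ow] = (flag_relu && sum < 0.f) ? 0.f : sum;
        }
      }
    }
  }
}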
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
// 4line
void
conv_depthwise_3x3s1p1_bias
(
float
*
dout
,
const
float
*
din
,
const
float
*
weights
,
const
float
*
bias
,
bool
flag_bias
,
const
int
num
,
const
int
ch_in
,
const
int
h_in
,
const
int
w_in
,
const
int
h_out
,
const
int
w_out
)
{
//! pad is done implicit
const
float
zero
[
8
]
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
};
//! for 4x6 convolution window
const
unsigned
int
right_pad_idx
[
8
]
=
{
5
,
4
,
3
,
2
,
1
,
0
,
0
,
0
};
float
*
zero_ptr
=
static_cast
<
float
*>
(
framework
::
CPUContext
::
Context
()
->
get_work_space
(
w_in
*
sizeof
(
float
)));
memset
(
zero_ptr
,
0
,
w_in
*
sizeof
(
float
));
float
*
write_ptr
=
zero_ptr
+
w_in
;
// printf("conv3x3_dw start \n");
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
int
w_stride
=
9
;
int
tile_w
=
(
w_in
+
3
)
>>
2
;
int
cnt_col
=
tile_w
-
2
;
unsigned
int
size_pad_right
=
(
unsigned
int
)(
1
+
(
tile_w
<<
2
)
-
w_in
);
uint32x4_t
vmask_rp1
=
vcgeq_u32
(
vld1q_u32
(
right_pad_idx
),
vdupq_n_u32
(
size_pad_right
));
uint32x4_t
vmask_rp2
=
vcgeq_u32
(
vld1q_u32
(
right_pad_idx
+
4
),
vdupq_n_u32
(
size_pad_right
));
uint32x4_t
vmask_result
=
vcgtq_u32
(
vld1q_u32
(
right_pad_idx
),
vdupq_n_u32
(
size_pad_right
));
unsigned
int
vmask
[
8
];
vst1q_u32
(
vmask
,
vmask_rp1
);
vst1q_u32
(
vmask
+
4
,
vmask_rp2
);
unsigned
int
rmask
[
4
];
vst1q_u32
(
rmask
,
vmask_result
);
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
#ifdef __aarch64__
    for (int c = 0; c < ch_in; c++) {
      float *dout_ptr = dout_batch + c * size_out_channel;
      const float *din_ch_ptr = din_batch + c * size_in_channel;

      float bias_val = flag_bias ? bias[c] : 0.f;
      float vbias[4] = {bias_val, bias_val, bias_val, bias_val};

      const float *wei_ptr = weights + c * w_stride;
      float32x4_t wr0 = vld1q_f32(wei_ptr);
      float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
      float32x4_t wr2 = vld1q_f32(wei_ptr + 6);

      float *doutr0 = dout_ptr;
      float *doutr1 = doutr0 + w_out;
      float *doutr2 = doutr1 + w_out;
      float *doutr3 = doutr2 + w_out;

      const float *dr0 = din_ch_ptr;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *dr4 = dr3 + w_in;
      const float *dr5 = dr4 + w_in;

      const float *din_ptr0 = dr0;
      const float *din_ptr1 = dr1;
      const float *din_ptr2 = dr2;
      const float *din_ptr3 = dr3;
      const float *din_ptr4 = dr4;
      const float *din_ptr5 = dr5;

      for (int i = 0; i < h_in; i += 4) {
        //! process top pad pad_h = 1
        din_ptr0 = dr0;
        din_ptr1 = dr1;
        din_ptr2 = dr2;
        din_ptr3 = dr3;
        din_ptr4 = dr4;
        din_ptr5 = dr5;

        doutr0 = dout_ptr;
        doutr1 = doutr0 + w_out;
        doutr2 = doutr1 + w_out;
        doutr3 = doutr2 + w_out;

        if (i == 0) {
          din_ptr0 = zero_ptr;
          din_ptr1 = dr0;
          din_ptr2 = dr1;
          din_ptr3 = dr2;
          din_ptr4 = dr3;
          din_ptr5 = dr4;
          dr0 = dr3;
          dr1 = dr4;
          dr2 = dr5;
        } else {
          dr0 = dr4;
          dr1 = dr5;
          dr2 = dr1 + w_in;
        }
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        dr5 = dr4 + w_in;

        //! process bottom pad
        if (i + 5 > h_in) {
          switch (i + 5 - h_in) {
            case 5:
              din_ptr1 = zero_ptr;
            case 4:
              din_ptr2 = zero_ptr;
            case 3:
              din_ptr3 = zero_ptr;
            case 2:
              din_ptr4 = zero_ptr;
            case 1:
              din_ptr5 = zero_ptr;
            default:
              break;
          }
        }
        //! process bottom remain
        if (i + 4 > h_out) {
          switch (i + 4 - h_out) {
            case 3:
              doutr1 = write_ptr;
            case 2:
              doutr2 = write_ptr;
            case 1:
              doutr3 = write_ptr;
            default:
              break;
          }
        }

        int cnt = cnt_col;
asm
volatile
(
"PRFM PLDL1KEEP, [%[din_ptr0]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr1]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr2]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr3]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr4]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr5]]
\n
"
"movi v21.4s, #0x0
\n
"
/* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ext v16.16b, %[vzero].16b, v0.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234 */
// left
// r0
"fmla v12.4s, v0.4s, %[w0].s[1]
\n
"
/* outr00 += din0_0123 *
w0[1]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"sub %[din_ptr0], %[din_ptr0], #4
\n
"
/* din_ptr0-- */
"sub %[din_ptr1], %[din_ptr1], #4
\n
"
/* din_ptr0-- */
"fmla v12.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0012 *
w0[0]*/
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"sub %[din_ptr2], %[din_ptr2], #4
\n
"
/* din_ptr0-- */
"sub %[din_ptr3], %[din_ptr3], #4
\n
"
/* din_ptr0-- */
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_1234 *
w0[2]*/
"ext v16.16b, %[vzero].16b, v2.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[1]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v2.4s, %[w1].s[1]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"sub %[din_ptr4], %[din_ptr4], #4
\n
"
/* din_ptr0-- */
"sub %[din_ptr5], %[din_ptr5], #4
\n
"
/* din_ptr0-- */
"fmla v13.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v4.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[1]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v4.4s, %[w1].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmla v12.4s , v4.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w2[1]*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v6.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[1]
\n
"
/*outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v6.4s, %[w1].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmla v13.4s , v6.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w2[1]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v8.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234 */
// r4
"fmla v15.4s , v8.4s, %[w1].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmla v14.4s , v8.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w2[1]*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
/* vst1q_f32() */
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
/* vst1q_f32() */
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v10.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234 */
// r5
"fmla v15.4s , v10.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
/* vst1q_f32() */
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
/* vst1q_f32() */
"cmp %[cnt], #1
\n
"
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"blt 3f
\n
"
// mid
"1:
\n
"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8
\n
"
/* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8
\n
"
/* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
"fmla v15.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
"fmla v15.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
"fmla v15.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
"subs %[cnt], %[cnt], #1
\n
"
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"bne 1b
\n
"
// right
"3:
\n
"
"ld1 {v18.4s, v19.4s}, [%[vmask]]
\n
"
"ld1 {v22.4s}, [%[doutr0]]
\n
"
"ld1 {v23.4s}, [%[doutr1]]
\n
"
"ld1 {v24.4s}, [%[doutr2]]
\n
"
"ld1 {v25.4s}, [%[doutr3]]
\n
"
"bif v0.16b, %[vzero].16b, v18.16b
\n
"
"bif v1.16b, %[vzero].16b, v19.16b
\n
"
"bif v2.16b, %[vzero].16b, v18.16b
\n
"
"bif v3.16b, %[vzero].16b, v19.16b
\n
"
"bif v4.16b, %[vzero].16b, v18.16b
\n
"
"bif v5.16b, %[vzero].16b, v19.16b
\n
"
"bif v6.16b, %[vzero].16b, v18.16b
\n
"
"bif v7.16b, %[vzero].16b, v19.16b
\n
"
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"bif v8.16b, %[vzero].16b, v18.16b
\n
"
"bif v9.16b, %[vzero].16b, v19.16b
\n
"
"bif v10.16b, %[vzero].16b, v18.16b
\n
"
"bif v11.16b, %[vzero].16b, v19.16b
\n
"
"fmla v12.4s, v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v18.4s}, [%[rmask]]
\n
"
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8
\n
"
/* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8
\n
"
/* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"bif v12.16b, v22.16b, v18.16b
\n
"
"fmla v15.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"bif v13.16b, v23.16b, v18.16b
\n
"
"fmla v15.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"bif v14.16b, v24.16b, v18.16b
\n
"
"fmla v15.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"bif v15.16b, v25.16b, v18.16b
\n
"
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
:
[
cnt
]
"+r"
(
cnt
),
[
din_ptr0
]
"+r"
(
din_ptr0
),
[
din_ptr1
]
"+r"
(
din_ptr1
),
[
din_ptr2
]
"+r"
(
din_ptr2
),
[
din_ptr3
]
"+r"
(
din_ptr3
),
[
din_ptr4
]
"+r"
(
din_ptr4
),
[
din_ptr5
]
"+r"
(
din_ptr5
),
[
doutr0
]
"+r"
(
doutr0
),
[
doutr1
]
"+r"
(
doutr1
),
[
doutr2
]
"+r"
(
doutr2
),
[
doutr3
]
"+r"
(
doutr3
)
:
[
w0
]
"w"
(
wr0
),
[
w1
]
"w"
(
wr1
),
[
w2
]
"w"
(
wr2
),
[
bias_val
]
"r"
(
vbias
),
[
vmask
]
"r"
(
vmask
),
[
rmask
]
"r"
(
rmask
),
[
vzero
]
"w"
(
vzero
)
:
"cc"
,
"memory"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
);
        dout_ptr = dout_ptr + 4 * w_out;
      }
    }
#else
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;

      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float bias_val = flag_bias ? bias[i] : 0.f;

      float *dout_channel = dout_batch + i * size_out_channel;

      const float *dr0 = din_channel;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;

      const float *din0_ptr = nullptr;
      const float *din1_ptr = nullptr;
      const float *din2_ptr = nullptr;
      const float *din3_ptr = nullptr;

      float *doutr0 = nullptr;
      float *doutr1 = nullptr;

      float *ptr_zero = const_cast<float *>(zero);

      for (int i = 0; i < h_in; i += 2) {
        //! process top pad pad_h = 1
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        din3_ptr = dr3;

        doutr0 = dout_channel;
        doutr1 = dout_channel + w_out;
        // unsigned int* rst_mask = rmask;

        if (i == 0) {
          din0_ptr = zero_ptr;
          din1_ptr = dr0;
          din2_ptr = dr1;
          din3_ptr = dr2;
          dr0 = dr1;
          dr1 = dr2;
          dr2 = dr3;
          dr3 = dr2 + w_in;
        } else {
          dr0 = dr2;
          dr1 = dr3;
          dr2 = dr1 + w_in;
          dr3 = dr2 + w_in;
        }
        //! process bottom pad
        if (i + 3 > h_in) {
          switch (i + 3 - h_in) {
            case 3:
              din1_ptr = zero_ptr;
            case 2:
              din2_ptr = zero_ptr;
            case 1:
              din3_ptr = zero_ptr;
            default:
              break;
          }
        }
        //! process bottom remain
        if (i + 2 > h_out) {
          doutr1 = write_ptr;
        }

        int cnt = cnt_col;
        unsigned int *rmask_ptr = rmask;
        unsigned int *vmask_ptr = vmask;
asm
volatile
(
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0
\n
"
"vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1
\n
"
"vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2
\n
"
"vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"vext.32 q6, %q[vzero], q8, #3 @ 0012
\n
"
"vext.32 q7, q8, q9, #1 @ 1234
\n
"
// left
// r0
"vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"sub %[din0_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"sub %[din1_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"sub %[din2_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"sub %[din3_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, %q[vzero], q10, #3 @ 0012
\n
"
"vext.32 q7, q10, q11, #1 @ 1234
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0
\n
"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0
\n
"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, %q[vzero], q12, #3 @ 0012
\n
"
"vext.32 q7, q12, q13, #1 @ 1234
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]
\n
"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, %q[vzero], q14, #3 @ 0012
\n
"
"vext.32 q7, q14, q15, #1 @ 1234
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]
\n
"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
"cmp %[cnt], #1 @ check whether has "
"mid cols
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"blt 3f @ jump to main loop start "
"point
\n
"
// mid
"1: @ right pad entry
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
"subs %[cnt], #1 @ loop count minus 1
\n
"
"vdup.32 q5, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point
\n
"
// right
"3: @ right pad entry
\n
"
"vld1.32 {d19}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d27}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d31}, [%[vmask]]! @ load din r0
\n
"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d19}, [%[rmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[rmask]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0
\n
"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vbif d8, d16, d19 @ bit select, deal with right pad
\n
"
"vbif d9, d17, d23 @ bit select, deal with right pad
\n
"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad
\n
"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
:
[
dout_ptr1
]
"+r"
(
doutr0
),
[
dout_ptr2
]
"+r"
(
doutr1
),
[
din0_ptr
]
"+r"
(
din0_ptr
),
[
din1_ptr
]
"+r"
(
din1_ptr
),
[
din2_ptr
]
"+r"
(
din2_ptr
),
[
din3_ptr
]
"+r"
(
din3_ptr
),
[
cnt
]
"+r"
(
cnt
),
[
rmask
]
"+r"
(
rmask_ptr
),
[
vmask
]
"+r"
(
vmask_ptr
)
:
[
wr0
]
"w"
(
wr0
),
[
wr1
]
"w"
(
wr1
),
[
wr2
]
"w"
(
wr2
),
[
bias_val
]
"r"
(
bias_val
),
[
vzero
]
"w"
(
vzero
)
:
"cc"
,
"memory"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
);
        dout_channel += 2 * w_out;
      }  //! end of processing mid rows
    }
#endif
  }
}
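// The right-edge handling in the kernel above relies on lane masks rather
// than a scalar tail loop: a mask is built once by comparing a constant
// lane-index vector against the number of padded lanes, and loads near the
// right border are blended with zero through a bit-select. A small intrinsics
// sketch of the idea (illustrative names and a simplified mask layout, not
// the exact vmask/rmask arrangement used above; the caller must still keep
// the 4-float load itself inside the row, as the kernels do by backing up
// the input pointer):
static inline float32x4_t load_with_right_pad(const float *src,
                                              unsigned int valid_lanes) {
  const uint32_t lane_idx[4] = {0, 1, 2, 3};
  // lane i stays valid while i < valid_lanes
  uint32x4_t mask = vcltq_u32(vld1q_u32(lane_idx), vdupq_n_u32(valid_lanes));
  float32x4_t v = vld1q_f32(src);
  // keep the valid lanes, force the padded lanes to zero
  return vbslq_f32(mask, v, vdupq_n_f32(0.f));
}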
/**
* \brief depthwise convolution kernel 3x3, stride 2
*/
// w_in > 7
void
conv_depthwise_3x3s2p1_bias
(
float
*
dout
,
const
float
*
din
,
const
float
*
weights
,
const
float
*
bias
,
bool
flag_bias
,
const
int
num
,
const
int
ch_in
,
const
int
h_in
,
const
int
w_in
,
const
int
h_out
,
const
int
w_out
)
{
int
right_pad_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
int
out_pad_idx
[
4
]
=
{
0
,
1
,
2
,
3
};
int
size_pad_bottom
=
h_out
*
2
-
h_in
;
int
cnt_col
=
(
w_out
>>
2
)
-
2
;
int
size_right_remain
=
w_in
-
(
7
+
cnt_col
*
8
);
if
(
size_right_remain
>=
9
)
{
cnt_col
++
;
size_right_remain
-=
8
;
}
int
cnt_remain
=
(
size_right_remain
==
8
)
?
4
:
(
w_out
%
4
);
//
int
size_right_pad
=
w_out
*
2
-
w_in
;
uint32x4_t
vmask_rp1
=
vcgtq_s32
(
vdupq_n_s32
(
size_right_remain
),
vld1q_s32
(
right_pad_idx
));
// 0 2 4 6
uint32x4_t
vmask_rp2
=
vcgtq_s32
(
vdupq_n_s32
(
size_right_remain
),
vld1q_s32
(
right_pad_idx
+
4
));
// 1 3 5 7
uint32x4_t
wmask
=
vcgtq_s32
(
vdupq_n_s32
(
cnt_remain
),
vld1q_s32
(
out_pad_idx
));
// 0 1 2 3
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
float
*
zero_ptr
=
static_cast
<
float
*>
(
framework
::
CPUContext
::
Context
()
->
get_work_space
(
w_in
*
sizeof
(
float
)));
memset
(
zero_ptr
,
0
,
w_in
*
sizeof
(
float
));
float
*
write_ptr
=
zero_ptr
+
w_in
;
unsigned
int
dmask
[
12
];
vst1q_u32
(
dmask
,
vmask_rp1
);
vst1q_u32
(
dmask
+
4
,
vmask_rp2
);
vst1q_u32
(
dmask
+
8
,
wmask
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
ch_in
;
++
i
)
{
const
float
*
din_channel
=
din_batch
+
i
*
size_in_channel
;
float
*
dout_channel
=
dout_batch
+
i
*
size_out_channel
;
const
float
*
weight_ptr
=
weights
+
i
*
9
;
float32x4_t
wr0
=
vld1q_f32
(
weight_ptr
);
float32x4_t
wr1
=
vld1q_f32
(
weight_ptr
+
3
);
float32x4_t
wr2
=
vld1q_f32
(
weight_ptr
+
6
);
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
float32x4_t
wbias
;
float
bias_c
=
0.
f
;
if
(
flag_bias
)
{
wbias
=
vdupq_n_f32
(
bias
[
i
]);
bias_c
=
bias
[
i
];
}
else
{
wbias
=
vdupq_n_f32
(
0.
f
);
}
const
float
*
dr0
=
din_channel
;
const
float
*
dr1
=
dr0
+
w_in
;
const
float
*
dr2
=
dr1
+
w_in
;
const
float
*
dr3
=
dr2
+
w_in
;
const
float
*
dr4
=
dr3
+
w_in
;
const
float
*
din0_ptr
=
dr0
;
const
float
*
din1_ptr
=
dr1
;
const
float
*
din2_ptr
=
dr2
;
const
float
*
din3_ptr
=
dr3
;
const
float
*
din4_ptr
=
dr4
;
float
*
doutr0
=
dout_channel
;
float
*
doutr0_ptr
=
nullptr
;
float
*
doutr1_ptr
=
nullptr
;
#ifdef __aarch64__
for
(
int
i
=
0
;
i
<
h_in
;
i
+=
4
)
{
din0_ptr
=
dr0
;
din1_ptr
=
dr1
;
din2_ptr
=
dr2
;
din3_ptr
=
dr3
;
din4_ptr
=
dr4
;
doutr0_ptr
=
doutr0
;
doutr1_ptr
=
doutr0
+
w_out
;
if
(
i
==
0
)
{
din0_ptr
=
zero_ptr
;
din1_ptr
=
dr0
;
din2_ptr
=
dr1
;
din3_ptr
=
dr2
;
din4_ptr
=
dr3
;
dr0
=
dr3
;
dr1
=
dr4
;
}
else
{
dr0
=
dr4
;
dr1
=
dr0
+
w_in
;
}
dr2
=
dr1
+
w_in
;
dr3
=
dr2
+
w_in
;
dr4
=
dr3
+
w_in
;
//! process bottom pad
if
(
i
+
4
>
h_in
)
{
switch
(
i
+
4
-
h_in
)
{
case
4
:
din1_ptr
=
zero_ptr
;
case
3
:
din2_ptr
=
zero_ptr
;
case
2
:
din3_ptr
=
zero_ptr
;
case
1
:
din4_ptr
=
zero_ptr
;
default:
break
;
}
}
//! process output pad
if
(
i
/
2
+
2
>
h_out
)
{
doutr1_ptr
=
write_ptr
;
}
int
cnt
=
cnt_col
;
asm
volatile
(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0:
\n
"
"prfm pldl1keep, [%[inptr0]]
\n
"
"prfm pldl1keep, [%[inptr1]]
\n
"
"prfm pldl1keep, [%[inptr2]]
\n
"
"prfm pldl1keep, [%[inptr3]]
\n
"
"prfm pldl1keep, [%[inptr4]]
\n
"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"ext v10.16b, %[vzero].16b, v1.16b, #12
\n
"
// v10 = {0,1,3,5}
// r0
"fmul v11.4s, v0.4s, %[w0].s[1]
\n
"
// {0,2,4,6} * w01
"fmul v12.4s, v1.4s, %[w0].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w0].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v3.16b, #12
\n
"
// v10 = {0,1,3,5}
"sub %[inptr0], %[inptr0], #4
\n
"
"sub %[inptr1], %[inptr1], #4
\n
"
// r1
"fmla v11.4s, v2.4s, %[w1].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v12.4s, v3.4s, %[w1].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w1].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v5.16b, #12
\n
"
// v10 = {0,1,3,5}
"sub %[inptr2], %[inptr2], #4
\n
"
"sub %[inptr3], %[inptr3], #4
\n
"
// r2
"fmul v13.4s, v4.4s, %[w0].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v11.4s, v4.4s, %[w2].s[1]
\n
"
// {0,2,4,6} * w01
"fmul v14.4s, v5.4s, %[w0].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v12.4s, v5.4s, %[w2].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w0].s[0]
\n
"
// {0,1,3,5} * w00
"fmla v16.4s, v10.4s, %[w2].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v7.16b, #12
\n
"
// v10 = {0,1,3,5}
"sub %[inptr4], %[inptr4], #4
\n
"
// r3
"fmla v13.4s, v6.4s, %[w1].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v14.4s, v7.4s, %[w1].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w1].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v9.16b, #12
\n
"
// v10 = {0,1,3,5}
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v14.4s, v9.4s, %[w2].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w2].s[0]
\n
"
// {0,1,3,5} * w00
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
"fadd v17.4s, v17.4s, v13.4s
\n
"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"ld1 {v15.4s}, [%[inptr0]]
\n
"
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"fadd v17.4s, v17.4s, v14.4s
\n
"
"ld1 {v18.4s}, [%[inptr1]]
\n
"
"ld1 {v19.4s}, [%[inptr2]]
\n
"
"ext v10.16b, v0.16b, v15.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld1 {v20.4s}, [%[inptr3]]
\n
"
"ld1 {v21.4s}, [%[inptr4]]
\n
"
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"cmp %[cnt], #1
\n
"
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"blt 1f
\n
"
// mid
"2:
\n
"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"ld1 {v15.4s}, [%[inptr0]]
\n
"
"ld1 {v18.4s}, [%[inptr1]]
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"fadd v17.4s, v17.4s, v13.4s
\n
"
"ld1 {v19.4s}, [%[inptr2]]
\n
"
"ld1 {v20.4s}, [%[inptr3]]
\n
"
"ld1 {v21.4s}, [%[inptr4]]
\n
"
"fadd v17.4s, v17.4s, v14.4s
\n
"
"ext v10.16b, v0.16b, v15.16b, #4
\n
"
// v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"subs %[cnt], %[cnt], #1
\n
"
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"bne 2b
\n
"
// right
"1:
\n
"
"cmp %[remain], #1
\n
"
"blt 4f
\n
"
"3:
\n
"
"bif v0.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]]
\n
"
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
"ld1 {v1.4s}, [%[outptr1]]
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"bif v16.16b, v0.16b, %[wmask].16b
\n
"
// pipei
"fadd v17.4s, v17.4s, v13.4s
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"fadd v17.4s, v17.4s, v14.4s
\n
"
"bif v17.16b, v1.16b, %[wmask].16b
\n
"
// pipei
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"4:
\n
"
:
[
inptr0
]
"+r"
(
din0_ptr
),
[
inptr1
]
"+r"
(
din1_ptr
),
[
inptr2
]
"+r"
(
din2_ptr
),
[
inptr3
]
"+r"
(
din3_ptr
),
[
inptr4
]
"+r"
(
din4_ptr
),
[
outptr0
]
"+r"
(
doutr0_ptr
),
[
outptr1
]
"+r"
(
doutr1_ptr
),
[
cnt
]
"+r"
(
cnt
)
:
[
vzero
]
"w"
(
vzero
),
[
w0
]
"w"
(
wr0
),
[
w1
]
"w"
(
wr1
),
[
w2
]
"w"
(
wr2
),
[
remain
]
"r"
(
cnt_remain
),
[
mask1
]
"w"
(
vmask_rp1
),
[
mask2
]
"w"
(
vmask_rp2
),
[
wmask
]
"w"
(
wmask
),
[
vbias
]
"w"
(
wbias
)
:
"cc"
,
"memory"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
);
        doutr0 = doutr0 + 2 * w_out;
      }
#else
for
(
int
i
=
0
;
i
<
h_in
;
i
+=
2
)
{
din0_ptr
=
dr0
;
din1_ptr
=
dr1
;
din2_ptr
=
dr2
;
doutr0_ptr
=
doutr0
;
if
(
i
==
0
)
{
din0_ptr
=
zero_ptr
;
din1_ptr
=
dr0
;
din2_ptr
=
dr1
;
dr0
=
dr1
;
dr1
=
dr2
;
dr2
=
dr1
+
w_in
;
}
else
{
dr0
=
dr2
;
dr1
=
dr0
+
w_in
;
dr2
=
dr1
+
w_in
;
}
//! process bottom pad
if
(
i
+
2
>
h_in
)
{
switch
(
i
+
2
-
h_in
)
{
case
2
:
din1_ptr
=
zero_ptr
;
case
1
:
din2_ptr
=
zero_ptr
;
default:
break
;
}
}
int
cnt
=
cnt_col
;
unsigned
int
*
mask_ptr
=
dmask
;
asm
volatile
(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0:
\n
"
"vmov.u32 q9, #0
\n
"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1
\n
"
// v11={0,2,4,6} v12={1,3,5,7}, q10, q11
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v11={0,2,4,6} v12={1,3,5,7}, q12, q13
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1
\n
"
// v13={0,2,4,6} v14={1,3,5,7}, q14, q15
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"vext.32 q6, q9, q11, #3 @ shift right 1 "
"data
\n
"
// q2 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift right 1 "
"data
\n
"
// q6 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift right 1 "
"data
\n
"
// q6 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, "
"out0
\n
"
// q11 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, "
"out0
\n
"
// q12 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, "
"out0
\n
"
// q6 * w00
"sub %[din0_ptr], #4 @ inpitr0 - 1
\n
"
"sub %[din1_ptr], #4 @ inpitr1 - 1
\n
"
"sub %[din2_ptr], #4 @ inpitr2 - 1
\n
"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q11 * w01
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q12 * w02
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w00
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, "
"out1
\n
"
// q0 * w01
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, "
"out1
\n
"
// q1 * w02
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, "
"out1
\n
"
// q2 * w00
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"cmp %[cnt], #1
\n
"
"blt 1f
\n
"
// mid
"2:
\n
"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0
\n
"
// q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"vext.32 q6, q10, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1
\n
"
// q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1
\n
"
// q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"subs %[cnt], #1
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"bne 2b
\n
"
// right
"1:
\n
"
"cmp %[remain], #1
\n
"
"blt 3f
\n
"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"vbif q10, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vext.32 q6, q10, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output
\n
"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask
\n
"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q6 * w02
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vbif.f32 q3, q10, q11 @ write mask
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"3:
\n
"
:
[
din0_ptr
]
"+r"
(
din0_ptr
),
[
din1_ptr
]
"+r"
(
din1_ptr
),
[
din2_ptr
]
"+r"
(
din2_ptr
),
[
outptr
]
"+r"
(
doutr0_ptr
),
[
cnt
]
"+r"
(
cnt
),
[
mask_ptr
]
"+r"
(
mask_ptr
)
:
[
remain
]
"r"
(
cnt_remain
),
[
wr0
]
"w"
(
wr0
),
[
wr1
]
"w"
(
wr1
),
[
wr2
]
"w"
(
wr2
),
[
bias
]
"r"
(
bias_c
)
:
"cc"
,
"memory"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
);
        doutr0 = doutr0 + w_out;
      }
#endif
    }
  }
}
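// conv_depthwise_3x3s2p1_bias above leans on de-interleaving loads
// (ld2/vld2): one load splits a row into its even columns {0,2,4,6} and odd
// columns {1,3,5,7}, so each kernel row contributes three multiply-accumulates
// per four stride-2 outputs. A minimal intrinsics sketch of that inner step
// for a single kernel row (function and variable names are illustrative only;
// the caller must guarantee that row[0..11] is readable, as the kernels do
// with their border handling):
static inline float32x4_t dw3x3_s2_row_step(const float *row, float w0,
                                            float w1, float w2,
                                            float32x4_t acc) {
  float32x4x2_t v = vld2q_f32(row);   // val[0]={0,2,4,6}, val[1]={1,3,5,7}
  float32x4_t next = vld1q_f32(row + 8);                // columns 8..11
  float32x4_t shifted = vextq_f32(v.val[0], next, 1);   // {2,4,6,8}
  acc = vmlaq_n_f32(acc, v.val[0], w0);  // x[2j]     * w0
  acc = vmlaq_n_f32(acc, v.val[1], w1);  // x[2j + 1] * w1
  acc = vmlaq_n_f32(acc, shifted, w2);   // x[2j + 2] * w2
  return acc;
}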
// 4line
void conv_depthwise_3x3s1p1_bias_relu(float *dout, const float *din,
                                      const float *weights, const float *bias,
                                      bool flag_bias, const int num,
                                      const int ch_in, const int h_in,
                                      const int w_in, const int h_out,
                                      const int w_out) {
//! pad is done implicit
const
float
zero
[
8
]
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
};
//! for 4x6 convolution window
const
unsigned
int
right_pad_idx
[
8
]
=
{
5
,
4
,
3
,
2
,
1
,
0
,
0
,
0
};
// printf("conv3x3_dw start \n");
int
size_in_channel
=
w_in
*
h_in
;
int
size_out_channel
=
w_out
*
h_out
;
int
w_stride
=
9
;
int
tile_w
=
(
w_in
+
3
)
>>
2
;
int
tile_h
=
(
h_in
+
3
)
>>
2
;
int
cnt_col
=
tile_w
-
2
;
float
*
zero_ptr
=
static_cast
<
float
*>
(
framework
::
CPUContext
::
Context
()
->
get_work_space
(
w_in
*
sizeof
(
float
)));
memset
(
zero_ptr
,
0
,
w_in
*
sizeof
(
float
));
float
*
write_ptr
=
zero_ptr
+
w_in
;
unsigned
int
size_pad_right
=
(
unsigned
int
)(
1
+
(
tile_w
<<
2
)
-
w_in
);
int
size_pad_bottom
=
(
unsigned
int
)(
1
+
(
tile_h
<<
2
)
-
h_in
);
uint32x4_t
vmask_rp1
=
vcgeq_u32
(
vld1q_u32
(
right_pad_idx
),
vdupq_n_u32
(
size_pad_right
));
uint32x4_t
vmask_rp2
=
vcgeq_u32
(
vld1q_u32
(
right_pad_idx
+
4
),
vdupq_n_u32
(
size_pad_right
));
uint32x4_t
vmask_result
=
vcgtq_u32
(
vld1q_u32
(
right_pad_idx
),
vdupq_n_u32
(
size_pad_right
));
unsigned
int
vmask
[
8
];
vst1q_u32
(
vmask
,
vmask_rp1
);
vst1q_u32
(
vmask
+
4
,
vmask_rp2
);
unsigned
int
rmask
[
4
];
vst1q_u32
(
rmask
,
vmask_result
);
float32x4_t
vzero
=
vdupq_n_f32
(
0.
f
);
for
(
int
n
=
0
;
n
<
num
;
++
n
)
{
const
float
*
din_batch
=
din
+
n
*
ch_in
*
size_in_channel
;
float
*
dout_batch
=
dout
+
n
*
ch_in
*
size_out_channel
;
#pragma omp parallel for
#ifdef __aarch64__
for
(
int
c
=
0
;
c
<
ch_in
;
c
++
)
{
float
*
dout_ptr
=
dout_batch
+
c
*
size_out_channel
;
const
float
*
din_ch_ptr
=
din_batch
+
c
*
size_in_channel
;
float
bias_val
=
flag_bias
?
bias
[
c
]
:
0.
f
;
float
vbias
[
4
]
=
{
bias_val
,
bias_val
,
bias_val
,
bias_val
};
const
float
*
wei_ptr
=
weights
+
c
*
w_stride
;
float32x4_t
wr0
=
vld1q_f32
(
wei_ptr
);
float32x4_t
wr1
=
vld1q_f32
(
wei_ptr
+
3
);
float32x4_t
wr2
=
vld1q_f32
(
wei_ptr
+
6
);
float
*
doutr0
=
dout_ptr
;
float
*
doutr1
=
doutr0
+
w_out
;
float
*
doutr2
=
doutr1
+
w_out
;
float
*
doutr3
=
doutr2
+
w_out
;
const
float
*
dr0
=
din_ch_ptr
;
const
float
*
dr1
=
dr0
+
w_in
;
const
float
*
dr2
=
dr1
+
w_in
;
const
float
*
dr3
=
dr2
+
w_in
;
const
float
*
dr4
=
dr3
+
w_in
;
const
float
*
dr5
=
dr4
+
w_in
;
const
float
*
din_ptr0
=
dr0
;
const
float
*
din_ptr1
=
dr1
;
const
float
*
din_ptr2
=
dr2
;
const
float
*
din_ptr3
=
dr3
;
const
float
*
din_ptr4
=
dr4
;
const
float
*
din_ptr5
=
dr5
;
for
(
int
i
=
0
;
i
<
h_in
;
i
+=
4
)
{
//! process top pad pad_h = 1
din_ptr0
=
dr0
;
din_ptr1
=
dr1
;
din_ptr2
=
dr2
;
din_ptr3
=
dr3
;
din_ptr4
=
dr4
;
din_ptr5
=
dr5
;
doutr0
=
dout_ptr
;
doutr1
=
doutr0
+
w_out
;
doutr2
=
doutr1
+
w_out
;
doutr3
=
doutr2
+
w_out
;
if
(
i
==
0
)
{
din_ptr0
=
zero_ptr
;
din_ptr1
=
dr0
;
din_ptr2
=
dr1
;
din_ptr3
=
dr2
;
din_ptr4
=
dr3
;
din_ptr5
=
dr4
;
dr0
=
dr3
;
dr1
=
dr4
;
dr2
=
dr5
;
}
else
{
dr0
=
dr4
;
dr1
=
dr5
;
dr2
=
dr1
+
w_in
;
}
dr3
=
dr2
+
w_in
;
dr4
=
dr3
+
w_in
;
dr5
=
dr4
+
w_in
;
//! process bottom pad
if
(
i
+
5
>
h_in
)
{
switch
(
i
+
5
-
h_in
)
{
case
5
:
din_ptr1
=
zero_ptr
;
case
4
:
din_ptr2
=
zero_ptr
;
case
3
:
din_ptr3
=
zero_ptr
;
case
2
:
din_ptr4
=
zero_ptr
;
case
1
:
din_ptr5
=
zero_ptr
;
default:
break
;
}
}
//! process bottom remain
if
(
i
+
4
>
h_out
)
{
switch
(
i
+
4
-
h_out
)
{
case
3
:
doutr1
=
write_ptr
;
case
2
:
doutr2
=
write_ptr
;
case
1
:
doutr3
=
write_ptr
;
default:
break
;
}
}
int
cnt
=
cnt_col
;
asm
volatile
(
"PRFM PLDL1KEEP, [%[din_ptr0]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr1]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr2]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr3]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr4]]
\n
"
"PRFM PLDL1KEEP, [%[din_ptr5]]
\n
"
"movi v21.4s, #0x0
\n
"
/* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ext v16.16b, %[vzero].16b, v0.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234 */
// left
// r0
"fmla v12.4s, v0.4s, %[w0].s[1]
\n
"
/* outr00 += din0_0123 *
w0[1]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"sub %[din_ptr0], %[din_ptr0], #4
\n
"
/* din_ptr0-- */
"sub %[din_ptr1], %[din_ptr1], #4
\n
"
/* din_ptr0-- */
"fmla v12.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0012 *
w0[0]*/
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"sub %[din_ptr2], %[din_ptr2], #4
\n
"
/* din_ptr0-- */
"sub %[din_ptr3], %[din_ptr3], #4
\n
"
/* din_ptr0-- */
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_1234 *
w0[2]*/
"ext v16.16b, %[vzero].16b, v2.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[1]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v2.4s, %[w1].s[1]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"sub %[din_ptr4], %[din_ptr4], #4
\n
"
/* din_ptr0-- */
"sub %[din_ptr5], %[din_ptr5], #4
\n
"
/* din_ptr0-- */
"fmla v13.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v4.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[1]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v4.4s, %[w1].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmla v12.4s , v4.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w2[1]*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v6.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[1]
\n
"
/*outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v6.4s, %[w1].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmla v13.4s , v6.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w2[1]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w0].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v8.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234 */
// r4
"fmla v15.4s , v8.4s, %[w1].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmla v14.4s , v8.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w2[1]*/
"fmax v12.4s, v12.4s, %[vzero].4s
\n
"
/*relu*/
"fmax v13.4s, v13.4s, %[vzero].4s
\n
"
/*relu*/
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w1].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
/* vst1q_f32() */
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
/* vst1q_f32() */
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v10.16b, #12
\n
"
/* v16 = 00123*/
"ext v17.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234 */
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
// r5
"fmla v15.4s , v10.4s, %[w2].s[1]
\n
"
/* outr00 += din2_0123 *
w1[1]*/
"fmax v14.4s, v14.4s, %[vzero].4s
\n
"
/*relu*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w2].s[0]
\n
"
/* outr00 += din2_0123 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
/* vst1q_f32() */
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din1_0123 *
w0[1]*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
"fmax v15.4s, v15.4s, %[vzero].4s
\n
"
/*relu*/
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
/* vst1q_f32() */
"cmp %[cnt], #1
\n
"
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"blt 3f
\n
"
// mid
"1:
\n
"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v1.4s}, [%[din_ptr0]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8
\n
"
/* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8
\n
"
/* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]]
\n
"
/*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmax v12.4s, v12.4s, %[vzero].4s
\n
"
/*relu*/
"fmla v15.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
"ld1 {v7.4s}, [%[din_ptr3]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmax v13.4s, v13.4s, %[vzero].4s
\n
"
/*relu*/
"fmla v15.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
"ld1 {v9.4s}, [%[din_ptr4]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v13.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"ld1 {v10.4s}, [%[din_ptr5]], #16
\n
"
/*vld1q_f32(din_ptr0)*/
"fmax v14.4s, v14.4s, %[vzero].4s
\n
"
/*relu*/
"fmla v15.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
"ld1 {v11.4s}, [%[din_ptr5]]
\n
"
/*vld1q_f32(din_ptr0)*/
"ld1 {v14.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
"subs %[cnt], %[cnt], #1
\n
"
"fmax v15.4s, v15.4s, %[vzero].4s
\n
"
/*relu*/
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
"ld1 {v15.4s}, [%[bias_val]]
\n
"
/*vdupq_n_f32(bias_val)*/
"bne 1b
\n
"
// right
"3:
\n
"
"ld1 {v18.4s, v19.4s}, [%[vmask]]
\n
"
"ld1 {v22.4s}, [%[doutr0]]
\n
"
"ld1 {v23.4s}, [%[doutr1]]
\n
"
"ld1 {v24.4s}, [%[doutr2]]
\n
"
"ld1 {v25.4s}, [%[doutr3]]
\n
"
"bif v0.16b, %[vzero].16b, v18.16b
\n
"
"bif v1.16b, %[vzero].16b, v19.16b
\n
"
"bif v2.16b, %[vzero].16b, v18.16b
\n
"
"bif v3.16b, %[vzero].16b, v19.16b
\n
"
"bif v4.16b, %[vzero].16b, v18.16b
\n
"
"bif v5.16b, %[vzero].16b, v19.16b
\n
"
"bif v6.16b, %[vzero].16b, v18.16b
\n
"
"bif v7.16b, %[vzero].16b, v19.16b
\n
"
"ext v16.16b, v0.16b, v1.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8
\n
"
/* v16 = 2345 */
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"bif v8.16b, %[vzero].16b, v18.16b
\n
"
"bif v9.16b, %[vzero].16b, v19.16b
\n
"
"bif v10.16b, %[vzero].16b, v18.16b
\n
"
"bif v11.16b, %[vzero].16b, v19.16b
\n
"
"fmla v12.4s, v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"ld1 {v18.4s}, [%[rmask]]
\n
"
"fmla v12.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8
\n
"
/* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8
\n
"
/* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v12.4s, v12.4s, %[vzero].4s
\n
"
/*relu*/
"fmla v15.4s , v16.4s, %[w0].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v12.16b, v22.16b, v18.16b
\n
"
"fmla v15.4s , v17.4s, %[w0].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8
\n
"
/* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"st1 {v12.4s}, [%[doutr0]], #16
\n
"
"fmax v13.4s, v13.4s, %[vzero].4s
\n
"
/*relu*/
"fmla v15.4s , v16.4s, %[w1].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v13.16b, v23.16b, v18.16b
\n
"
"fmla v15.4s , v17.4s, %[w1].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4
\n
"
/* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8
\n
"
/* v16 = 2345 */
"st1 {v13.4s}, [%[doutr1]], #16
\n
"
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]
\n
"
/* outr00 += din0_0123 *
w0[0]*/
"fmax v14.4s, v14.4s, %[vzero].4s
\n
"
/*relu*/
"fmla v15.4s , v16.4s, %[w2].s[1]
\n
"
/* outr00 += din0_1234 *
w0[1]*/
"bif v14.16b, v24.16b, v18.16b
\n
"
"fmla v15.4s , v17.4s, %[w2].s[2]
\n
"
/* outr00 += din0_2345 *
w0[2]*/
"st1 {v14.4s}, [%[doutr2]], #16
\n
"
"fmax v15.4s, v15.4s, %[vzero].4s
\n
"
/*relu*/
"bif v15.16b, v25.16b, v18.16b
\n
"
"st1 {v15.4s}, [%[doutr3]], #16
\n
"
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
              [vzero] "w"(vzero)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        dout_ptr = dout_ptr + 4 * w_out;
      }
    }
#else
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float bias_val = flag_bias ? bias[i] : 0.f;
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *dr0 = din_channel;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *din0_ptr = nullptr;
      const float *din1_ptr = nullptr;
      const float *din2_ptr = nullptr;
      const float *din3_ptr = nullptr;
      float *doutr0 = nullptr;
      float *doutr1 = nullptr;
      float *ptr_zero = const_cast<float *>(zero);
      for (int i = 0; i < h_in; i += 2) {
        //! process top pad pad_h = 1
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        din3_ptr = dr3;
        doutr0 = dout_channel;
        doutr1 = dout_channel + w_out;
        // unsigned int* rst_mask = rmask;
        if (i == 0) {
          din0_ptr = zero_ptr;
          din1_ptr = dr0;
          din2_ptr = dr1;
          din3_ptr = dr2;
          dr0 = dr1;
          dr1 = dr2;
          dr2 = dr3;
          dr3 = dr2 + w_in;
        } else {
          dr0 = dr2;
          dr1 = dr3;
          dr2 = dr1 + w_in;
          dr3 = dr2 + w_in;
        }
        //! process bottom pad
        if (i + 3 > h_in) {
          switch (i + 3 - h_in) {
            case 3: din1_ptr = zero_ptr;
            case 2: din2_ptr = zero_ptr;
            case 1: din3_ptr = zero_ptr;
            default: break;
          }
        }
        //! process bottom remain
        if (i + 2 > h_out) {
          doutr1 = write_ptr;
        }
        int cnt = cnt_col;
        unsigned int *rmask_ptr = rmask;
        unsigned int *vmask_ptr = vmask;
asm
volatile
(
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0
\n
"
"vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1
\n
"
"vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2
\n
"
"vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"vext.32 q6, %q[vzero], q8, #3 @ 0012
\n
"
"vext.32 q7, q8, q9, #1 @ 1234
\n
"
// left
// r0
"vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"sub %[din0_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"sub %[din1_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"sub %[din2_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"sub %[din3_ptr], #12 @ 1pad + 2 float data overlap
\n
"
"vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, %q[vzero], q10, #3 @ 0012
\n
"
"vext.32 q7, q10, q11, #1 @ 1234
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0
\n
"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0
\n
"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, %q[vzero], q12, #3 @ 0012
\n
"
"vext.32 q7, q12, q13, #1 @ 1234
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]
\n
"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, %q[vzero], q14, #3 @ 0012
\n
"
"vext.32 q7, q14, q15, #1 @ 1234
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0
\n
"
"vmax.f32 q4, q4, %q[vzero] @ relu
\n
"
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]
\n
"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vmax.f32 q5, q5, %q[vzero] @ relu
\n
"
"cmp %[cnt], #1 @ check whether has "
"mid cols
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
"vdup.32 q5, %[bias_val] @ and
\n
"
// q5
// =
// vbias
"blt 3f @ jump to main loop start "
"point
\n
"
// mid
"1: @ right pad entry
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"pld [%[din3_ptr]] @ preload data
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0
\n
"
"vmax.f32 q4, q4, %q[vzero] @ relu
\n
"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
"vdup.32 q4, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"vmax.f32 q5, q5, %q[vzero] @ relu
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
"subs %[cnt], #1 @ loop count minus 1
\n
"
"vdup.32 q5, %[bias_val] @ and
\n
"
// q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point
\n
"
// right
"3: @ right pad entry
\n
"
"vld1.32 {d19}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d27}, [%[vmask]]! @ load din r0
\n
"
"vld1.32 {d31}, [%[vmask]]! @ load din r0
\n
"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vext.32 q6, q8, q9, #1 @ 1234
\n
"
"vext.32 q7, q8, q9, #2 @ 2345
\n
"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]
\n
"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad
\n
"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad
\n
"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad
\n
"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]
\n
"
"vext.32 q6, q10, q11, #1 @ 1234
\n
"
"vext.32 q7, q10, q11, #2 @ 2345
\n
"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d19}, [%[rmask]]! @ load din r0
\n
"
"vld1.32 {d23}, [%[rmask]]! @ load din r0
\n
"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0
\n
"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0
\n
"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q12, q13, #1 @ 1234
\n
"
"vext.32 q7, q12, q13, #2 @ 2345
\n
"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]
\n
"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]
\n
"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]
\n
"
"vext.32 q6, q14, q15, #1 @ 1234
\n
"
"vext.32 q7, q14, q15, #2 @ 2345
\n
"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]
\n
"
"vmax.f32 q4, q4, %q[vzero] @ relu
\n
"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]
\n
"
"vbif d8, d16, d19 @ bit select, deal with right pad
\n
"
"vbif d9, d17, d23 @ bit select, deal with right pad
\n
"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]
\n
"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer
\n
"
"vmax.f32 q5, q5, %q[vzero] @ relu
\n
"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad
\n
"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad
\n
"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer
\n
"
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [din3_ptr] "+r"(din3_ptr),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr), [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias_val] "r"(bias_val), [vzero] "w"(vzero)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
              "q12", "q13", "q14", "q15");
        dout_channel += 2 * w_out;
      }  //! end of processing mid rows
    }
#endif
  }
}
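// For readers following the NEON kernels above: a minimal scalar sketch of
// the same computation (3x3 depthwise convolution, stride 1, pad 1, optional
// bias, fused ReLU), one filter per channel. The helper below is illustrative
// only; its name and the single-image NCHW layout assumption are ours, not
// part of the tuned implementation, and it is a readability aid rather than a
// fast path.
static inline void depthwise_conv3x3s1p1_relu_reference(
    float *dout, const float *din, const float *weights, const float *bias,
    bool flag_bias, int ch, int h, int w) {
  for (int c = 0; c < ch; ++c) {
    const float *src = din + c * h * w;
    const float *wgt = weights + c * 9;
    float *dst = dout + c * h * w;  // with pad 1, h_out == h and w_out == w
    for (int oh = 0; oh < h; ++oh) {
      for (int ow = 0; ow < w; ++ow) {
        float sum = flag_bias ? bias[c] : 0.f;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            int ih = oh + kh - 1;
            int iw = ow + kw - 1;
            // out-of-range taps read as zero (implicit padding)
            if (ih >= 0 && ih < h && iw >= 0 && iw < w) {
              sum += src[ih * w + iw] * wgt[kh * 3 + kw];
            }
          }
        }
        dst[oh * w + ow] = sum > 0.f ? sum : 0.f;  // fused ReLU
      }
    }
  }
}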
/**
 * \brief depthwise convolution kernel 3x3, stride 2, with relu
*/
// w_in > 7
void conv_depthwise_3x3s2p1_bias_relu(float *dout,
                                      const float *din,
                                      const float *weights,
                                      const float *bias,
                                      bool flag_bias,
                                      const int num,
                                      const int ch_in,
                                      const int h_in,
                                      const int w_in,
                                      const int h_out,
                                      const int w_out) {
  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  int out_pad_idx[4] = {0, 1, 2, 3};
  int size_pad_bottom = h_out * 2 - h_in;
  int cnt_col = (w_out >> 2) - 2;
  int size_right_remain = w_in - (7 + cnt_col * 8);
  if (size_right_remain >= 9) {
    cnt_col++;
    size_right_remain -= 8;
  }
  int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
  int size_right_pad = w_out * 2 - w_in;
  uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
                                   vld1q_s32(right_pad_idx));  // 0 2 4 6
  uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
                                   vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
  uint32x4_t wmask =
      vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx));  // 0 1 2 3
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  float *zero_ptr = static_cast<float *>(
      framework::CPUContext::Context()->get_work_space(w_in * sizeof(float)));
  memset(zero_ptr, 0, w_in * sizeof(float));
  float *write_ptr = zero_ptr + w_in;
  unsigned int dmask[12];
  vst1q_u32(dmask, vmask_rp1);
  vst1q_u32(dmask + 4, vmask_rp2);
  vst1q_u32(dmask + 8, wmask);
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float32x4_t vzero = vdupq_n_f32(0.f);
      float32x4_t wbias;
      float bias_c = 0.f;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
        bias_c = bias[i];
      } else {
        wbias = vdupq_n_f32(0.f);
      }
      const float *dr0 = din_channel;
      const float *dr1 = dr0 + w_in;
      const float *dr2 = dr1 + w_in;
      const float *dr3 = dr2 + w_in;
      const float *dr4 = dr3 + w_in;
      const float *din0_ptr = dr0;
      const float *din1_ptr = dr1;
      const float *din2_ptr = dr2;
      const float *din3_ptr = dr3;
      const float *din4_ptr = dr4;
      float *doutr0 = dout_channel;
      float *doutr0_ptr = nullptr;
      float *doutr1_ptr = nullptr;
#ifdef __aarch64__
      for (int i = 0; i < h_in; i += 4) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        din3_ptr = dr3;
        din4_ptr = dr4;
        doutr0_ptr = doutr0;
        doutr1_ptr = doutr0 + w_out;
        if (i == 0) {
          din0_ptr = zero_ptr;
          din1_ptr = dr0;
          din2_ptr = dr1;
          din3_ptr = dr2;
          din4_ptr = dr3;
          dr0 = dr3;
          dr1 = dr4;
        } else {
          dr0 = dr4;
          dr1 = dr0 + w_in;
        }
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        //! process bottom pad
        if (i + 4 > h_in) {
          switch (i + 4 - h_in) {
            case 4: din1_ptr = zero_ptr;
            case 3: din2_ptr = zero_ptr;
            case 2: din3_ptr = zero_ptr;
            case 1: din4_ptr = zero_ptr;
            default: break;
          }
        }
        //! process output pad
        if (i / 2 + 2 > h_out) {
          doutr1_ptr = write_ptr;
        }
        int cnt = cnt_col;
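        // Descriptive note (added for readability): in the stride-2 assembly
        // below, "ld2" de-interleaves each input row into even columns
        // {0,2,4,6} and odd columns {1,3,5,7}; a shifted copy ({0,1,3,5} at
        // the left edge, {2,4,6,8} in the main loop) supplies the third tap
        // of every 3x3 window. Each iteration consumes 8 input columns from
        // five rows and emits 4 output columns for two output rows.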
asm
volatile
(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0:
\n
"
"prfm pldl1keep, [%[inptr0]]
\n
"
"prfm pldl1keep, [%[inptr1]]
\n
"
"prfm pldl1keep, [%[inptr2]]
\n
"
"prfm pldl1keep, [%[inptr3]]
\n
"
"prfm pldl1keep, [%[inptr4]]
\n
"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"ext v10.16b, %[vzero].16b, v1.16b, #12
\n
"
// v10 = {0,1,3,5}
// r0
"fmul v11.4s, v0.4s, %[w0].s[1]
\n
"
// {0,2,4,6} * w01
"fmul v12.4s, v1.4s, %[w0].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w0].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v3.16b, #12
\n
"
// v10 = {0,1,3,5}
"sub %[inptr0], %[inptr0], #4
\n
"
"sub %[inptr1], %[inptr1], #4
\n
"
// r1
"fmla v11.4s, v2.4s, %[w1].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v12.4s, v3.4s, %[w1].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w1].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v5.16b, #12
\n
"
// v10 = {0,1,3,5}
"sub %[inptr2], %[inptr2], #4
\n
"
"sub %[inptr3], %[inptr3], #4
\n
"
// r2
"fmul v13.4s, v4.4s, %[w0].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v11.4s, v4.4s, %[w2].s[1]
\n
"
// {0,2,4,6} * w01
"fmul v14.4s, v5.4s, %[w0].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v12.4s, v5.4s, %[w2].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w0].s[0]
\n
"
// {0,1,3,5} * w00
"fmla v16.4s, v10.4s, %[w2].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v7.16b, #12
\n
"
// v10 = {0,1,3,5}
"sub %[inptr4], %[inptr4], #4
\n
"
// r3
"fmla v13.4s, v6.4s, %[w1].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v14.4s, v7.4s, %[w1].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w1].s[0]
\n
"
// {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v9.16b, #12
\n
"
// v10 = {0,1,3,5}
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[1]
\n
"
// {0,2,4,6} * w01
"fmla v14.4s, v9.4s, %[w2].s[2]
\n
"
// {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w2].s[0]
\n
"
// {0,1,3,5} * w00
"fmax v16.4s, v16.4s, %[vzero].4s
\n
"
/* relu */
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
"fadd v17.4s, v17.4s, v13.4s
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"ld1 {v15.4s}, [%[inptr0]]
\n
"
"fadd v17.4s, v17.4s, v14.4s
\n
"
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"ld1 {v18.4s}, [%[inptr1]]
\n
"
"ld1 {v19.4s}, [%[inptr2]]
\n
"
"ext v10.16b, v0.16b, v15.16b, #4
\n
"
// v10 = {2,4,6,8}
"fmax v17.4s, v17.4s, %[vzero].4s
\n
"
/* relu */
"ld1 {v20.4s}, [%[inptr3]]
\n
"
"ld1 {v21.4s}, [%[inptr4]]
\n
"
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"cmp %[cnt], #1
\n
"
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"blt 1f
\n
"
// mid
"2:
\n
"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32
\n
"
// v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32
\n
"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32
\n
"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4
\n
"
// v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32
\n
"
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32
\n
"
"ld1 {v15.4s}, [%[inptr0]]
\n
"
"ld1 {v18.4s}, [%[inptr1]]
\n
"
"fmax v16.4s, v16.4s, %[vzero].4s
\n
"
/* relu */
"fadd v17.4s, v17.4s, v13.4s
\n
"
"ld1 {v19.4s}, [%[inptr2]]
\n
"
"ld1 {v20.4s}, [%[inptr3]]
\n
"
"ld1 {v21.4s}, [%[inptr4]]
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"fadd v17.4s, v17.4s, v14.4s
\n
"
"ext v10.16b, v0.16b, v15.16b, #4
\n
"
// v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b
\n
"
// v10 = vbias
"subs %[cnt], %[cnt], #1
\n
"
"fmax v17.4s, v17.4s, %[vzero].4s
\n
"
/* relu */
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"and v17.16b, %[vbias].16b, %[vbias].16b
\n
"
// v16 = vbias
"bne 2b
\n
"
// right
"1:
\n
"
"cmp %[remain], #1
\n
"
"blt 4f
\n
"
"3:
\n
"
"bif v0.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b
\n
"
// pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b
\n
"
// pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2]
\n
"
// {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2]
\n
"
// {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4
\n
"
// v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]]
\n
"
"fadd v16.4s, v16.4s, v11.4s
\n
"
"fadd v16.4s, v16.4s, v12.4s
\n
"
"ld1 {v1.4s}, [%[outptr1]]
\n
"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0]
\n
"
// {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1]
\n
"
// {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2]
\n
"
// {2,4,6,8} * w02
"fmax v16.4s, v16.4s, %[vzero].4s
\n
"
/* relu */
"fadd v17.4s, v17.4s, v13.4s
\n
"
"bif v16.16b, v0.16b, %[wmask].16b
\n
"
// pipei
"fadd v17.4s, v17.4s, v14.4s
\n
"
"st1 {v16.4s}, [%[outptr0]], #16
\n
"
"fmax v17.4s, v17.4s, %[vzero].4s
\n
"
/* relu */
"bif v17.16b, v1.16b, %[wmask].16b
\n
"
// pipei
"st1 {v17.4s}, [%[outptr1]], #16
\n
"
"4:
\n
"
            : [inptr0] "+r"(din0_ptr), [inptr1] "+r"(din1_ptr),
              [inptr2] "+r"(din2_ptr), [inptr3] "+r"(din3_ptr),
              [inptr4] "+r"(din4_ptr), [outptr0] "+r"(doutr0_ptr),
              [outptr1] "+r"(doutr1_ptr), [cnt] "+r"(cnt)
            : [vzero] "w"(vzero), [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
              [mask2] "w"(vmask_rp2), [wmask] "w"(wmask), [vbias] "w"(wbias)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21");
        doutr0 = doutr0 + 2 * w_out;
      }
#else
      for (int i = 0; i < h_in; i += 2) {
        din0_ptr = dr0;
        din1_ptr = dr1;
        din2_ptr = dr2;
        doutr0_ptr = doutr0;
        if (i == 0) {
          din0_ptr = zero_ptr;
          din1_ptr = dr0;
          din2_ptr = dr1;
          dr0 = dr1;
          dr1 = dr2;
          dr2 = dr1 + w_in;
        } else {
          dr0 = dr2;
          dr1 = dr0 + w_in;
          dr2 = dr1 + w_in;
        }
        //! process bottom pad
        if (i + 2 > h_in) {
          switch (i + 2 - h_in) {
            case 2: din1_ptr = zero_ptr;
            case 1: din2_ptr = zero_ptr;
            default: break;
          }
        }
        int cnt = cnt_col;
        unsigned int *mask_ptr = dmask;
asm
volatile
(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0:
\n
"
"vmov.u32 q9, #0
\n
"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1
\n
"
// v11={0,2,4,6} v12={1,3,5,7}, q10, q11
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v11={0,2,4,6} v12={1,3,5,7}, q12, q13
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1
\n
"
// v13={0,2,4,6} v14={1,3,5,7}, q14, q15
"pld [%[din0_ptr]] @ preload data
\n
"
"pld [%[din1_ptr]] @ preload data
\n
"
"pld [%[din2_ptr]] @ preload data
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"vext.32 q6, q9, q11, #3 @ shift right 1 "
"data
\n
"
// q2 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift right 1 "
"data
\n
"
// q6 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift right 1 "
"data
\n
"
// q6 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, "
"out0
\n
"
// q11 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, "
"out0
\n
"
// q12 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, "
"out0
\n
"
// q6 * w00
"sub %[din0_ptr], #4 @ inpitr0 - 1
\n
"
"sub %[din1_ptr], #4 @ inpitr1 - 1
\n
"
"sub %[din2_ptr], #4 @ inpitr2 - 1
\n
"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q11 * w01
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q12 * w02
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w00
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, "
"out1
\n
"
// q0 * w01
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, "
"out1
\n
"
// q1 * w02
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, "
"out1
\n
"
// q2 * w00
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vmax.f32 q3, q3, q9 @ relu
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"cmp %[cnt], #1
\n
"
"blt 1f
\n
"
// mid
"2:
\n
"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0
\n
"
// q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"vext.32 q6, q10, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1
\n
"
// q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1
\n
"
// q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2
\n
"
// v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vmax.f32 q3, q3, q9 @ relu
\n
"
"subs %[cnt], #1
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"bne 2b
\n
"
// right
"1:
\n
"
"cmp %[remain], #1
\n
"
"blt 3f
\n
"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q10 =
// vbias
"vbif q10, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vext.32 q6, q10, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1
\n
"
// q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output
\n
"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask
\n
"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q6 * w02
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vmax.f32 q3, q3, q9 @ relu
\n
"
"vbif.f32 q3, q10, q11 @ write mask
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"3:
\n
"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [outptr] "+r"(doutr0_ptr),
              [cnt] "+r"(cnt), [mask_ptr] "+r"(mask_ptr)
            : [remain] "r"(cnt_remain), [wr0] "w"(wr0), [wr1] "w"(wr1),
              [wr2] "w"(wr2), [bias] "r"(bias_c)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        doutr0 = doutr0 + w_out;
      }
#endif
    }
  }
}
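// For reference, the stride-2 kernel above computes, per channel and output
// position (oh, ow), the sum over (kh, kw) of
// din[2*oh + kh - 1][2*ow + kw - 1] * w[kh][kw], plus bias, followed by ReLU,
// with out-of-range rows and columns treated as zero. A minimal scalar sketch
// of one output row follows; the helper name and signature are illustrative
// only, not part of the optimized implementation.
static inline void depthwise_conv3x3s2p1_row_reference(
    float *out_row, const float *din, const float *wgt, float bias_val,
    bool relu, int h_in, int w_in, int w_out, int oh) {
  for (int ow = 0; ow < w_out; ++ow) {
    float sum = bias_val;
    for (int kh = 0; kh < 3; ++kh) {
      for (int kw = 0; kw < 3; ++kw) {
        int ih = 2 * oh + kh - 1;  // stride 2, pad 1
        int iw = 2 * ow + kw - 1;
        if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
          sum += din[ih * w_in + iw] * wgt[kh * 3 + kw];
        }
      }
    }
    out_row[ow] = relu ? (sum > 0.f ? sum : 0.f) : sum;
  }
}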
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s(float *dout,
                                   const float *din,
                                   const float *weights,
                                   const float *bias,
                                   bool flag_bias,
                                   const int num,
                                   const int ch_in,
                                   const int h_in,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out) {
  //! 3x3s1 convolution, implemented by direct algorithm
  //! pad is done implicit
  //! for 4x6 convolution window
  const int right_pad_idx[4] = {3, 2, 1, 0};
  const float zero[4] = {0.f, 0.f, 0.f, 0.f};
  float32x4_t vzero = vdupq_n_f32(0.f);
  uint32x4_t vmask_rp =
      vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *din_channel = din_batch + i * size_in_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float32x4_t wbias;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
      } else {
        wbias = vdupq_n_f32(0.f);
      }
      int hs = -1;
      int he = 3;
      float out_buf1[4];
      float out_buf2[4];
      float trash_buf[4];
      int h_cnt = (h_out + 1) >> 1;
      float *doutr0 = dout_channel;
      float *doutr1 = dout_channel + w_out;
      for (int j = 0; j < h_cnt; ++j) {
        const float *dr0 = din_channel + hs * w_in;
        const float *dr1 = dr0 + w_in;
        const float *dr2 = dr1 + w_in;
        const float *dr3 = dr2 + w_in;
        if (hs == -1) {
          dr0 = zero;
        }
        switch (he - h_in) {
          case 2:
            dr2 = zero;
            doutr1 = trash_buf;
          case 1:
            dr3 = zero;
          default:
            break;
        }
#ifdef __aarch64__
asm
volatile
(
"prfm pldl1keep, [%[din0]]
\n
"
"prfm pldl1keep, [%[din1]]
\n
"
"prfm pldl1keep, [%[din2]]
\n
"
"prfm pldl1keep, [%[din3]]
\n
"
"ld1 {v0.4s}, [%[din0]], #16
\n
"
"ld1 {v1.4s}, [%[din1]], #16
\n
"
"ld1 {v2.4s}, [%[din2]], #16
\n
"
"ld1 {v3.4s}, [%[din3]], #16
\n
"
"bif v0.16b, %[zero].16b, %[mask].16b
\n
"
// d0_1234
"bif v1.16b, %[zero].16b, %[mask].16b
\n
"
// d1_1234
"bif v2.16b, %[zero].16b, %[mask].16b
\n
"
// d2_1234
"bif v3.16b, %[zero].16b, %[mask].16b
\n
"
// d3_1234
"ext v4.16b, %[zero].16b, v0.16b, #12
\n
"
// d0_0123
"ext v5.16b, %[zero].16b, v1.16b, #12
\n
"
// d1_0123
"ext v6.16b, %[zero].16b, v2.16b, #12
\n
"
// d2_0123
"ext v7.16b, %[zero].16b, v3.16b, #12
\n
"
// d3_0123
"ext v8.16b, v0.16b, %[zero].16b, #4
\n
"
// d0_2340
"ext v9.16b, v1.16b, %[zero].16b, #4
\n
"
// d1_2340
"ext v10.16b, v2.16b, %[zero].16b, #4
\n
"
// d2_2340
"ext v11.16b, v3.16b, %[zero].16b, #4
\n
"
// d3_2340
"fmul v12.4s, v0.4s, %[wr0].s[1]
\n
"
"fmul v13.4s, v1.4s, %[wr0].s[1]
\n
"
"fmul v14.4s, v1.4s, %[wr1].s[1]
\n
"
"fmul v15.4s, v2.4s, %[wr1].s[1]
\n
"
"fmul v16.4s, v2.4s, %[wr2].s[1]
\n
"
"fmul v17.4s, v3.4s, %[wr2].s[1]
\n
"
"fmla v12.4s, v4.4s, %[wr0].s[0]
\n
"
"fmla v13.4s, v5.4s, %[wr0].s[0]
\n
"
"fmla v14.4s, v5.4s, %[wr1].s[0]
\n
"
"fmla v15.4s, v6.4s, %[wr1].s[0]
\n
"
"fmla v16.4s, v6.4s, %[wr2].s[0]
\n
"
"fmla v17.4s, v7.4s, %[wr2].s[0]
\n
"
"fmla v12.4s, v8.4s, %[wr0].s[2]
\n
"
"fmla v13.4s, v9.4s, %[wr0].s[2]
\n
"
"fmla v14.4s, v9.4s, %[wr1].s[2]
\n
"
"fmla v15.4s, v10.4s, %[wr1].s[2]
\n
"
"fmla v16.4s, v10.4s, %[wr2].s[2]
\n
"
"fmla v17.4s, v11.4s, %[wr2].s[2]
\n
"
"fadd v12.4s, v12.4s, v14.4s
\n
"
"fadd v12.4s, v12.4s, v16.4s
\n
"
"fadd v13.4s, v13.4s, v15.4s
\n
"
// out1
"fadd v13.4s, v13.4s, v17.4s
\n
"
// out2
"fadd v12.4s, v12.4s, %[bias].4s
\n
"
// out1 add bias
"fadd v13.4s, v13.4s, %[bias].4s
\n
"
// out2 add bias
"prfm pldl1keep, [%[out1]]
\n
"
"prfm pldl1keep, [%[out2]]
\n
"
"st1 {v12.4s}, [%[out1]]
\n
"
"st1 {v13.4s}, [%[out2]]
\n
"
          : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
            [din3] "+r"(dr3)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2), [zero] "w"(vzero),
            [mask] "w"(vmask_rp), [bias] "w"(wbias), [out1] "r"(out_buf1),
            [out2] "r"(out_buf2)
          : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
            "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
            "v17");
#else
asm
volatile
(
"pld [%[din0]]
\n
"
"pld [%[din1]]
\n
"
"pld [%[din2]]
\n
"
"pld [%[din3]]
\n
"
"vld1.32 {d12-d13}, [%[din0]]!
\n
"
"vld1.32 {d14-d15}, [%[din1]]!
\n
"
"vld1.32 {d16-d17}, [%[din2]]!
\n
"
"vld1.32 {d18-d19}, [%[din3]]!
\n
"
"vbif q6, %q[zero], %q[mask]
\n
"
// d0_1234
"vbif q7, %q[zero], %q[mask]
\n
"
// d1_1234
"vbif q8, %q[zero], %q[mask]
\n
"
// d2_1234
"vbif q9, %q[zero], %q[mask]
\n
"
// d3_1234
"vmul.f32 q14, q6, %e[wr0][1]
\n
"
"vmul.f32 q15, q7, %e[wr0][1]
\n
"
"vmla.f32 q14, q7, %e[wr1][1]
\n
"
"vmla.f32 q15, q8, %e[wr1][1]
\n
"
"vmla.f32 q14, q8, %e[wr2][1]
\n
"
"vmla.f32 q15, q9, %e[wr2][1]
\n
"
"vext.32 q10, %q[zero], q6, #3
\n
"
// d0_0123
"vext.32 q11, %q[zero], q7, #3
\n
"
// d1_0123
"vext.32 q12, %q[zero], q8, #3
\n
"
// d2_0123
"vext.32 q13, %q[zero], q9, #3
\n
"
// d3_0123
"vmla.f32 q14, q10, %e[wr0][0]
\n
"
"vmla.f32 q15, q11, %e[wr0][0]
\n
"
"vmla.f32 q14, q11, %e[wr1][0]
\n
"
"vmla.f32 q15, q12, %e[wr1][0]
\n
"
"vmla.f32 q14, q12, %e[wr2][0]
\n
"
"vmla.f32 q15, q13, %e[wr2][0]
\n
"
"vext.32 q10, q6, %q[zero], #1
\n
"
// d0_2340
"vext.32 q11, q7, %q[zero], #1
\n
"
// d1_2340
"vext.32 q12, q8, %q[zero], #1
\n
"
// d2_2340
"vext.32 q13, q9, %q[zero], #1
\n
"
// d3_2340
"vmla.f32 q14, q10, %f[wr0][0]
\n
"
"vmla.f32 q15, q11, %f[wr0][0]
\n
"
"vmla.f32 q14, q11, %f[wr1][0]
\n
"
"vmla.f32 q15, q12, %f[wr1][0]
\n
"
"vmla.f32 q14, q12, %f[wr2][0]
\n
"
// out1
"vmla.f32 q15, q13, %f[wr2][0]
\n
"
// out2
"vadd.f32 q14, q14, %q[bias]
\n
"
// out1 add bias
"vadd.f32 q15, q15, %q[bias]
\n
"
// out2 add bias
"pld [%[out1]]
\n
"
"pld [%[out2]]
\n
"
"vst1.32 {d28-d29}, [%[out1]]
\n
"
"vst1.32 {d30-d31}, [%[out2]]
\n
"
          : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
            [din3] "+r"(dr3)
          : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2), [zero] "w"(vzero),
            [mask] "w"(vmask_rp), [bias] "w"(wbias), [out1] "r"(out_buf1),
            [out2] "r"(out_buf2)
          : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12",
            "q13", "q14", "q15");
#endif //__aarch64__
        for (int w = 0; w < w_out; ++w) {
          *doutr0++ = out_buf1[w];
          *doutr1++ = out_buf2[w];
        }
        doutr0 = doutr1;
        doutr1 += w_out;
        hs += 2;
        he += 2;
      }  // end of processing heights
    }  // end of processing channels
  }  // end of processing batchs
}
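// The *_s variants above and below handle narrow inputs (width <= 4): each
// NEON pass computes a full 4-wide row into a stack buffer (out_buf1/out_buf2
// or out_buf) and only the first w_out values are copied out, while vmask_rp
// zeroes the lanes that fall beyond w_in. A sketch of that right-pad masking
// in plain C++ follows; the helper is hypothetical and assumes w_in <= 4.
static inline void load_row_padded4_reference(float dst[4], const float *src,
                                              int w_in) {
  for (int k = 0; k < 4; ++k) {
    // lane k is valid only when k < w_in; lanes past the end of the row read
    // as zero, mirroring bif(row, vzero, vmask_rp) in the kernels above
    dst[k] = (k < w_in) ? src[k] : 0.f;
  }
}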
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/
void conv_depthwise_3x3s2p1_bias_s(float *dout,
                                   const float *din,
                                   const float *weights,
                                   const float *bias,
                                   bool flag_bias,
                                   const int num,
                                   const int ch_in,
                                   const int h_in,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out) {
  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  int out_pad_idx[4] = {0, 1, 2, 3};
  float zeros[8] = {0.0f};
  uint32x4_t vmask_rp1 =
      vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx));  // 0 2 4 6
  uint32x4_t vmask_rp2 =
      vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  unsigned int dmask[8];
  vst1q_u32(dmask, vmask_rp1);
  vst1q_u32(dmask + 4, vmask_rp2);
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float bias_c = 0.f;
      if (flag_bias) {
        bias_c = bias[i];
      }
      float32x4_t vbias = vdupq_n_f32(bias_c);
      int hs = -1;
      int he = 2;
      float out_buf[4];
      for (int j = 0; j < h_out; ++j) {
        const float *dr0 = din_channel + hs * w_in;
        const float *dr1 = dr0 + w_in;
        const float *dr2 = dr1 + w_in;
        if (hs == -1) {
          dr0 = zeros;
        }
        if (he > h_in) {
          dr2 = zeros;
        }
        const float *din0_ptr = dr0;
        const float *din1_ptr = dr1;
        const float *din2_ptr = dr2;
        unsigned int *mask_ptr = dmask;
#ifdef __aarch64__
asm
volatile
(
// Load up 12 elements (3 vectors) from each of 8 sources.
"movi v9.4s, #0
\n
"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32
\n
"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32
\n
"
// v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32
\n
"
// v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32
\n
"
// v14={0,2,4,6}
// v15={1,3,5,7}
"bif v10.16b, v9.16b, v6.16b
\n
"
"bif v11.16b, v9.16b, v7.16b
\n
"
"bif v12.16b, v9.16b, v6.16b
\n
"
"bif v13.16b, v9.16b, v7.16b
\n
"
"bif v14.16b, v9.16b, v6.16b
\n
"
"bif v15.16b, v9.16b, v7.16b
\n
"
"ext v6.16b, v9.16b, v11.16b, #12
\n
"
// v6 =
// {0,1,3,5}
"ext v7.16b, v9.16b, v13.16b, #12
\n
"
// v7 =
// {0,1,3,5}
"ext v8.16b, v9.16b, v15.16b, #12
\n
"
// v8 =
// {0,1,3,5}
"fmul v4.4s, v10.4s, %[wr0].s[1]
\n
"
// v10 * w01
"fmul v5.4s, v11.4s, %[wr0].s[2]
\n
"
// v11 * w02
"fmul v6.4s, v6.4s, %[wr0].s[0]
\n
"
// v6 * w00
"fmla v4.4s, v12.4s, %[wr1].s[1]
\n
"
// v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[2]
\n
"
// v13 * w12
"fmla v6.4s, v7.4s, %[wr1].s[0]
\n
"
// v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[1]
\n
"
// v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[2]
\n
"
// v15 * w21
"fmla v6.4s, v8.4s, %[wr2].s[0]
\n
"
// v8 * w22
"fadd v4.4s, v4.4s, v5.4s
\n
"
"fadd v4.4s, v4.4s, v6.4s
\n
"
"fadd v4.4s, v4.4s, %[bias].4s
\n
"
"st1 {v4.4s}, [%[out]]
\n
"
:
[
din0_ptr
]
"+r"
(
din0_ptr
),
[
din1_ptr
]
"+r"
(
din1_ptr
),
[
din2_ptr
]
"+r"
(
din2_ptr
),
[
mask_ptr
]
"+r"
(
mask_ptr
)
:
[
wr0
]
"w"
(
wr0
),
[
wr1
]
"w"
(
wr1
),
[
wr2
]
"w"
(
wr2
),
[
bias
]
"w"
(
vbias
),
[
out
]
"r"
(
out_buf
)
:
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
);
#else
asm
volatile
(
// Load up 12 elements (3 vectors) from each of 8 sources.
"vmov.u32 q9, #0
\n
"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask
\n
"
"vdup.32 q3, %[bias] @ and
\n
"
// q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0
\n
"
// q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1
\n
"
// q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2
\n
"
// q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad
\n
"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad
\n
"
"vext.32 q6, q9, q11, #3 @ shift left 1
\n
"
// q6 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift left 1
\n
"
// q7 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift left 1
\n
"
// q8 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, "
"out0
\n
"
// q10 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, "
"out0
\n
"
// q11 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, "
"out0
\n
"
// q6 * w00
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0
\n
"
// q12 * w11
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0
\n
"
// q13 * w12
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0
\n
"
// q7 * w10
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, "
"out0
\n
"
// q14 * w20
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, "
"out0
\n
"
// q15 * w21
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, "
"out0
\n
"
// q8 * w22
"vadd.f32 q3, q3, q4 @ add
\n
"
"vadd.f32 q3, q3, q5 @ add
\n
"
"vst1.32 {d6-d7}, [%[out]]
\n
"
:
[
din0_ptr
]
"+r"
(
din0_ptr
),
[
din1_ptr
]
"+r"
(
din1_ptr
),
[
din2_ptr
]
"+r"
(
din2_ptr
),
[
mask_ptr
]
"+r"
(
mask_ptr
)
:
[
wr0
]
"w"
(
wr0
),
[
wr1
]
"w"
(
wr1
),
[
wr2
]
"w"
(
wr2
),
[
bias
]
"r"
(
bias_c
),
[
out
]
"r"
(
out_buf
)
:
"cc"
,
"memory"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
);
#endif //__aarch64__
for
(
int
w
=
0
;
w
<
w_out
;
++
w
)
{
*
dout_channel
++
=
out_buf
[
w
];
}
hs
+=
2
;
he
+=
2
;
}
}
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s_relu(float *dout,
                                        const float *din,
                                        const float *weights,
                                        const float *bias,
                                        bool flag_bias,
                                        const int num,
                                        const int ch_in,
                                        const int h_in,
                                        const int w_in,
                                        const int h_out,
                                        const int w_out) {
  //! 3x3s1 convolution, implemented by direct algorithm
  //! pad is done implicit
  //! for 4x6 convolution window
  const int right_pad_idx[4] = {3, 2, 1, 0};
  const float zero[4] = {0.f, 0.f, 0.f, 0.f};
  float32x4_t vzero = vdupq_n_f32(0.f);
  uint32x4_t vmask_rp =
      vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *din_channel = din_batch + i * size_in_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float32x4_t wbias;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
      } else {
        wbias = vdupq_n_f32(0.f);
      }
      int hs = -1;
      int he = 3;
      float out_buf1[4];
      float out_buf2[4];
      float trash_buf[4];
      int h_cnt = (h_out + 1) >> 1;
      float *doutr0 = dout_channel;
      float *doutr1 = dout_channel + w_out;
      for (int j = 0; j < h_cnt; ++j) {
        const float *dr0 = din_channel + hs * w_in;
        const float *dr1 = dr0 + w_in;
        const float *dr2 = dr1 + w_in;
        const float *dr3 = dr2 + w_in;
        if (hs == -1) {
          dr0 = zero;
        }
        switch (he - h_in) {
          case 2:
            dr2 = zero;
            doutr1 = trash_buf;
          case 1:
            dr3 = zero;
          default:
            break;
        }
#ifdef __aarch64__
        asm volatile(
            "prfm pldl1keep, [%[din0]]\n"
            "prfm pldl1keep, [%[din1]]\n"
            "prfm pldl1keep, [%[din2]]\n"
            "prfm pldl1keep, [%[din3]]\n"
            "ld1 {v0.4s}, [%[din0]], #16\n"
            "ld1 {v1.4s}, [%[din1]], #16\n"
            "ld1 {v2.4s}, [%[din2]], #16\n"
            "ld1 {v3.4s}, [%[din3]], #16\n"
            "bif v0.16b, %[zero].16b, %[mask].16b\n"  // d0_1234
            "bif v1.16b, %[zero].16b, %[mask].16b\n"  // d1_1234
            "bif v2.16b, %[zero].16b, %[mask].16b\n"  // d2_1234
            "bif v3.16b, %[zero].16b, %[mask].16b\n"  // d3_1234
            "ext v4.16b, %[zero].16b, v0.16b, #12\n"  // d0_0123
            "ext v5.16b, %[zero].16b, v1.16b, #12\n"  // d1_0123
            "ext v6.16b, %[zero].16b, v2.16b, #12\n"  // d2_0123
            "ext v7.16b, %[zero].16b, v3.16b, #12\n"  // d3_0123
            "ext v8.16b, v0.16b, %[zero].16b, #4\n"   // d0_2340
            "ext v9.16b, v1.16b, %[zero].16b, #4\n"   // d1_2340
            "ext v10.16b, v2.16b, %[zero].16b, #4\n"  // d2_2340
            "ext v11.16b, v3.16b, %[zero].16b, #4\n"  // d3_2340
            "fmul v12.4s, v0.4s, %[wr0].s[1]\n"
            "fmul v13.4s, v1.4s, %[wr0].s[1]\n"
            "fmul v14.4s, v1.4s, %[wr1].s[1]\n"
            "fmul v15.4s, v2.4s, %[wr1].s[1]\n"
            "fmul v16.4s, v2.4s, %[wr2].s[1]\n"
            "fmul v17.4s, v3.4s, %[wr2].s[1]\n"
            "fmla v12.4s, v4.4s, %[wr0].s[0]\n"
            "fmla v13.4s, v5.4s, %[wr0].s[0]\n"
            "fmla v14.4s, v5.4s, %[wr1].s[0]\n"
            "fmla v15.4s, v6.4s, %[wr1].s[0]\n"
            "fmla v16.4s, v6.4s, %[wr2].s[0]\n"
            "fmla v17.4s, v7.4s, %[wr2].s[0]\n"
            "fmla v12.4s, v8.4s, %[wr0].s[2]\n"
            "fmla v13.4s, v9.4s, %[wr0].s[2]\n"
            "fmla v14.4s, v9.4s, %[wr1].s[2]\n"
            "fmla v15.4s, v10.4s, %[wr1].s[2]\n"
            "fmla v16.4s, v10.4s, %[wr2].s[2]\n"
            "fmla v17.4s, v11.4s, %[wr2].s[2]\n"
            "fadd v12.4s, v12.4s, v14.4s\n"
            "fadd v12.4s, v12.4s, v16.4s\n"
            "fadd v13.4s, v13.4s, v15.4s\n"  // out1
            "fadd v13.4s, v13.4s, v17.4s\n"  // out2
            "fadd v12.4s, v12.4s, %[bias].4s\n"  // out1 add bias
            "fadd v13.4s, v13.4s, %[bias].4s\n"  // out2 add bias
            "prfm pldl1keep, [%[out1]]\n"
            "prfm pldl1keep, [%[out2]]\n"
            "fmax v12.4s, v12.4s, %[zero].4s\n"  // out1 -> relu
            "fmax v13.4s, v13.4s, %[zero].4s\n"  // out2 -> relu
            "st1 {v12.4s}, [%[out1]]\n"
            "st1 {v13.4s}, [%[out2]]\n"
            : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
              [din3] "+r"(dr3)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [zero] "w"(vzero), [mask] "w"(vmask_rp), [bias] "w"(wbias),
              [out1] "r"(out_buf1), [out2] "r"(out_buf2)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17");
#else
        asm volatile(
            "pld [%[din0]]\n"
            "pld [%[din1]]\n"
            "pld [%[din2]]\n"
            "pld [%[din3]]\n"
            "vld1.32 {d12-d13}, [%[din0]]!\n"
            "vld1.32 {d14-d15}, [%[din1]]!\n"
            "vld1.32 {d16-d17}, [%[din2]]!\n"
            "vld1.32 {d18-d19}, [%[din3]]!\n"
            "vbif q6, %q[zero], %q[mask]\n"  // d0_1234
            "vbif q7, %q[zero], %q[mask]\n"  // d1_1234
            "vbif q8, %q[zero], %q[mask]\n"  // d2_1234
            "vbif q9, %q[zero], %q[mask]\n"  // d3_1234
            "vmul.f32 q14, q6, %e[wr0][1]\n"
            "vmul.f32 q15, q7, %e[wr0][1]\n"
            "vmla.f32 q14, q7, %e[wr1][1]\n"
            "vmla.f32 q15, q8, %e[wr1][1]\n"
            "vmla.f32 q14, q8, %e[wr2][1]\n"
            "vmla.f32 q15, q9, %e[wr2][1]\n"
            "vext.32 q10, %q[zero], q6, #3\n"  // d0_0123
            "vext.32 q11, %q[zero], q7, #3\n"  // d1_0123
            "vext.32 q12, %q[zero], q8, #3\n"  // d2_0123
            "vext.32 q13, %q[zero], q9, #3\n"  // d3_0123
            "vmla.f32 q14, q10, %e[wr0][0]\n"
            "vmla.f32 q15, q11, %e[wr0][0]\n"
            "vmla.f32 q14, q11, %e[wr1][0]\n"
            "vmla.f32 q15, q12, %e[wr1][0]\n"
            "vmla.f32 q14, q12, %e[wr2][0]\n"
            "vmla.f32 q15, q13, %e[wr2][0]\n"
            "vext.32 q10, q6, %q[zero], #1\n"  // d0_2340
            "vext.32 q11, q7, %q[zero], #1\n"  // d1_2340
            "vext.32 q12, q8, %q[zero], #1\n"  // d2_2340
            "vext.32 q13, q9, %q[zero], #1\n"  // d3_2340
            "vmla.f32 q14, q10, %f[wr0][0]\n"
            "vmla.f32 q15, q11, %f[wr0][0]\n"
            "vmla.f32 q14, q11, %f[wr1][0]\n"
            "vmla.f32 q15, q12, %f[wr1][0]\n"
            "vmla.f32 q14, q12, %f[wr2][0]\n"  // out1
            "vmla.f32 q15, q13, %f[wr2][0]\n"  // out2
            "vadd.f32 q14, q14, %q[bias]\n"  // out1 add bias
            "vadd.f32 q15, q15, %q[bias]\n"  // out2 add bias
            "pld [%[out1]]\n"
            "pld [%[out2]]\n"
            "vmax.f32 q14, q14, %q[zero]\n"  // out1 -> relu
            "vmax.f32 q15, q15, %q[zero]\n"  // out2 -> relu
            "vst1.32 {d28-d29}, [%[out1]]\n"
            "vst1.32 {d30-d31}, [%[out2]]\n"
            : [din0] "+r"(dr0), [din1] "+r"(dr1), [din2] "+r"(dr2),
              [din3] "+r"(dr3)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [zero] "w"(vzero), [mask] "w"(vmask_rp), [bias] "w"(wbias),
              [out1] "r"(out_buf1), [out2] "r"(out_buf2)
            : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12",
              "q13", "q14", "q15");
#endif  //__aarch64__
        for (int w = 0; w < w_out; ++w) {
          *doutr0++ = out_buf1[w];
          *doutr1++ = out_buf2[w];
        }
        doutr0 = doutr1;
        doutr1 += w_out;
        hs += 2;
        he += 2;
      }  // end of processing heights
    }    // end of processing channels
  }      // end of processing batches
}
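// --------------------------------------------------------------------------
// Mask sketch (illustrative only). For the small-width s1p1 kernels above,
// with right_pad_idx = {3, 2, 1, 0}, lane k of
//   vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in))
// is all-ones exactly when k < w_in, so the bif/vbif instructions keep the
// first w_in input lanes and force the out-of-range lanes to zero. The
// helper below is a hypothetical scalar equivalent, included only to
// document that behaviour.
static inline void build_right_pad_mask_ref(int w_in, unsigned int mask[4]) {
  const int right_pad_idx[4] = {3, 2, 1, 0};
  for (int k = 0; k < 4; ++k) {
    // lane k is kept when right_pad_idx[k] >= 4 - w_in, i.e. when k < w_in
    mask[k] = (right_pad_idx[k] >= 4 - w_in) ? 0xffffffffu : 0u;
  }
}
// --------------------------------------------------------------------------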
/**
 * \brief depthwise convolution, kernel size 3x3, stride 2, pad 1, with bias
 * and relu, width <= 7
 */
void conv_depthwise_3x3s2p1_bias_s_relu(float *dout,
                                        const float *din,
                                        const float *weights,
                                        const float *bias,
                                        bool flag_bias,
                                        const int num,
                                        const int ch_in,
                                        const int h_in,
                                        const int w_in,
                                        const int h_out,
                                        const int w_out) {
  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  int out_pad_idx[4] = {0, 1, 2, 3};
  float zeros[8] = {0.0f};
  uint32x4_t vmask_rp1 =
      vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx));  // 0 2 4 6
  uint32x4_t vmask_rp2 =
      vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4));  // 1 3 5 7
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  unsigned int dmask[8];
  vst1q_u32(dmask, vmask_rp1);
  vst1q_u32(dmask + 4, vmask_rp2);
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
    for (int i = 0; i < ch_in; ++i) {
      const float *din_channel = din_batch + i * size_in_channel;
      float *dout_channel = dout_batch + i * size_out_channel;
      const float *weight_ptr = weights + i * 9;
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float bias_c = 0.f;
      if (flag_bias) {
        bias_c = bias[i];
      }
      float32x4_t vbias = vdupq_n_f32(bias_c);
      int hs = -1;
      int he = 2;
      float out_buf[4];
      for (int j = 0; j < h_out; ++j) {
        const float *dr0 = din_channel + hs * w_in;
        const float *dr1 = dr0 + w_in;
        const float *dr2 = dr1 + w_in;
        if (hs == -1) {
          dr0 = zeros;
        }
        if (he > h_in) {
          dr2 = zeros;
        }
        const float *din0_ptr = dr0;
        const float *din1_ptr = dr1;
        const float *din2_ptr = dr2;
        unsigned int *mask_ptr = dmask;
#ifdef __aarch64__
        asm volatile(
            // Load up 12 elements (3 vectors) from each of 8 sources.
            "movi v9.4s, #0\n"
            "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32\n"
            "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32\n"  // v10={0,2,4,6} v11={1,3,5,7}
            "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32\n"  // v13={0,2,4,6} v12={1,3,5,7}
            "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32\n"  // v14={0,2,4,6} v15={1,3,5,7}
            "bif v10.16b, v9.16b, v6.16b\n"
            "bif v11.16b, v9.16b, v7.16b\n"
            "bif v12.16b, v9.16b, v6.16b\n"
            "bif v13.16b, v9.16b, v7.16b\n"
            "bif v14.16b, v9.16b, v6.16b\n"
            "bif v15.16b, v9.16b, v7.16b\n"
            "ext v6.16b, v9.16b, v11.16b, #12\n"  // v6 = {0,1,3,5}
            "ext v7.16b, v9.16b, v13.16b, #12\n"  // v7 = {0,1,3,5}
            "ext v8.16b, v9.16b, v15.16b, #12\n"  // v8 = {0,1,3,5}
            "fmul v4.4s, v10.4s, %[wr0].s[1]\n"  // v10 * w01
            "fmul v5.4s, v11.4s, %[wr0].s[2]\n"  // v11 * w02
            "fmul v6.4s, v6.4s, %[wr0].s[0]\n"   // v6 * w00
            "fmla v4.4s, v12.4s, %[wr1].s[1]\n"  // v12 * w11
            "fmla v5.4s, v13.4s, %[wr1].s[2]\n"  // v13 * w12
            "fmla v6.4s, v7.4s, %[wr1].s[0]\n"   // v7 * w10
            "fmla v4.4s, v14.4s, %[wr2].s[1]\n"  // v14 * w20
            "fmla v5.4s, v15.4s, %[wr2].s[2]\n"  // v15 * w21
            "fmla v6.4s, v8.4s, %[wr2].s[0]\n"   // v8 * w22
            "fadd v4.4s, v4.4s, v5.4s\n"
            "fadd v4.4s, v4.4s, v6.4s\n"
            "fadd v4.4s, v4.4s, %[bias].4s\n"  // out add bias
            "fmax v4.4s, v4.4s, v9.4s\n"
            "st1 {v4.4s}, [%[out]]\n"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [mask_ptr] "+r"(mask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias] "w"(vbias), [out] "r"(out_buf)
            : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15");
#else
        asm volatile(
            // Load up 12 elements (3 vectors) from each of 8 sources.
            "vmov.u32 q9, #0\n"
            "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
            "vdup.32 q3, %[bias] @ and\n"  // q3 = vbias
            "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n"  // q10={0,2,4,6} q11={1,3,5,7}
            "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n"  // q13={0,2,4,6} q12={1,3,5,7}
            "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n"  // q14={0,2,4,6} q15={1,3,5,7}
            "vbif q10, q9, q6 @ bit select, deal "
            "with right pad\n"
            "vbif q11, q9, q7 @ bit select, deal "
            "with right pad\n"
            "vbif q12, q9, q6 @ bit select, deal "
            "with right pad\n"
            "vbif q13, q9, q7 @ bit select, deal "
            "with right pad\n"
            "vbif q14, q9, q6 @ bit select, deal "
            "with right pad\n"
            "vbif q15, q9, q7 @ bit select, deal "
            "with right pad\n"
            "vext.32 q6, q9, q11, #3 @ shift left 1\n"  // q6 = {0,1,3,5}
            "vext.32 q7, q9, q13, #3 @ shift left 1\n"  // q7 = {0,1,3,5}
            "vext.32 q8, q9, q15, #3 @ shift left 1\n"  // q8 = {0,1,3,5}
            "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, "
            "out0\n"  // q10 * w01
            "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, "
            "out0\n"  // q11 * w02
            "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, "
            "out0\n"  // q6 * w00
            "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
            "out0\n"  // q12 * w11
            "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
            "out0\n"  // q13 * w12
            "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
            "out0\n"  // q7 * w10
            "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, "
            "out0\n"  // q14 * w20
            "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, "
            "out0\n"  // q15 * w21
            "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, "
            "out0\n"  // q8 * w22
            "vadd.f32 q3, q3, q4 @ add\n"
            "vadd.f32 q3, q3, q5 @ add\n"
            "vmax.f32 q3, q3, q9 @ relu\n"
            "vst1.32 {d6-d7}, [%[out]]\n"
            : [din0_ptr] "+r"(din0_ptr), [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr), [mask_ptr] "+r"(mask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias] "r"(bias_c), [out] "r"(out_buf)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
#endif  //__aarch64__
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
        hs += 2;
        he += 2;
      }
    }
  }
}
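// --------------------------------------------------------------------------
// Row-window sketch (illustrative only). The stride-2 kernels above slide a
// 3-row window whose top edge starts at hs = -1 (one padding row above the
// image) and advances by 2 rows per output row; rows falling outside
// [0, h_in) are redirected to a zero buffer before the NEON block runs.
// A hypothetical scalar rendering of that bookkeeping, assuming a
// caller-provided zero row of at least w_in elements:
static inline void select_rows_s2_ref(const float *din_channel, int w_in,
                                      int h_in, int hs, const float *zero_row,
                                      const float *rows[3]) {
  for (int k = 0; k < 3; ++k) {
    int h = hs + k;  // hs == -1 on the first output row (pad == 1 on top)
    rows[k] = (h >= 0 && h < h_in) ? din_channel + h * w_in : zero_row;
  }
}
// --------------------------------------------------------------------------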
}  // namespace depthwise
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
#endif
test/net/test_benchmark.cpp View file @ 16a0bd75
...
@@ -59,12 +59,13 @@ int main(int argc, char* argv[]) {
     paddle_mobile.Predict(input);
   }
   auto time3 = time();
-  for (int i = 0; i < 10; ++i) {
+  int test_count = 100;
+  for (int i = 0; i < test_count; ++i) {
     paddle_mobile.Predict(input);
   }
   auto time4 = time();
-  std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
+  std::cout << "predict cost :" << time_diff(time3, time4) / test_count
+            << "ms\n";
   std::ostringstream os("output tensor size: ");
   output = paddle_mobile.Fetch();
   os << output->numel() << "\n" << output->data<float>()[0];
...
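For context, the benchmark follows a warm-up-then-measure pattern: a few untimed Predict calls, then a timed loop of test_count iterations whose mean is reported. A minimal self-contained sketch of that pattern using std::chrono (the BenchmarkMs helper and its lambda-based interface are illustrative assumptions, not part of this repository, which uses its own time()/time_diff helpers):

#include <chrono>
#include <functional>
#include <iostream>

// Illustrative helper: warm up, then report the mean latency of fn in ms.
static double BenchmarkMs(const std::function<void()> &fn, int warmup,
                          int test_count) {
  for (int i = 0; i < warmup; ++i) fn();  // untimed warm-up runs
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < test_count; ++i) fn();
  auto end = std::chrono::steady_clock::now();
  std::chrono::duration<double, std::milli> elapsed = end - start;
  return elapsed.count() / test_count;
}

// Hypothetical usage mirroring the benchmark above:
//   double ms = BenchmarkMs([&] { paddle_mobile.Predict(input); }, 10, 100);
//   std::cout << "predict cost :" << ms << "ms\n";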