Commit 3455ab0a

Authored Dec 17, 2019 by HappyAngel; committed by yiicy on Dec 17, 2019.

[lite][arm] add conv+relu6/leakyRelu fusion (#2599)

Parent: ec8353e8
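
Not part of the patch itself: a minimal scalar sketch of what "conv + relu6/leakyRelu fusion" computes per output element once the convolution sum is formed, assuming the activation parameter carries the relu6 clip value and the leaky-relu slope. The enum and names below are stand-ins, not Paddle-Lite's real ActivationParam.

#include <algorithm>

enum class Act { kNone, kRelu, kRelu6, kLeakyRelu };

// Applied in the same pass that writes the convolution output, so no extra
// activation kernel (and no extra memory round-trip) is needed.
inline float fuse_activation(float conv_out, Act act, float six, float alpha) {
  switch (act) {
    case Act::kRelu:      return std::max(conv_out, 0.f);
    case Act::kRelu6:     return std::min(std::max(conv_out, 0.f), six);
    case Act::kLeakyRelu: return conv_out >= 0.f ? conv_out : conv_out * alpha;
    default:              return conv_out;
  }
}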
Showing 13 changed files with 4,666 additions and 1,396 deletions (+4666 −1396).
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc       +12   −6
lite/backends/arm/math/conv3x3s1_direct_fp32.cc          +5    −2
lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc    +2449 −565
lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc     +852  −349
lite/backends/arm/math/conv3x3s2_direct_fp32.cc          +5    −2
lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc    +52   −75
lite/backends/arm/math/conv_block_utils.h                +1246 −368
lite/backends/arm/math/conv_depthwise.h                  +2    −0
lite/backends/arm/math/conv_impl.cc                      +3    −0
lite/kernels/arm/conv_compute.cc                         +1    −1
lite/operators/conv_op.cc                                +28   −0
lite/operators/conv_op.h                                 +0    −28
lite/tests/math/conv_compute_test.cc                     +11   −0
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @ 3455ab0a
...
...
@@ -295,7 +295,8 @@ void conv_compute_6x6_3x3(const float* input,
                           hout,
                           wout,
                           false,
-                          zero_ptr);
+                          zero_ptr,
+                          nullptr);
     }
   } else {
     for (int ci = 0; ci < oc_4; ++ci) {
@@ -341,7 +342,8 @@ void conv_compute_6x6_3x3(const float* input,
                           hout,
                           wout,
                           false,
-                          zero_ptr);
+                          zero_ptr,
+                          nullptr);
     }
   }
 }
@@ -562,7 +564,8 @@ void conv_compute_2x2_3x3(const float* input,
                           hout,
                           wout,
                           false,
-                          zero_ptr);
+                          zero_ptr,
+                          nullptr);
     }
   } else {
     for (int ci = 0; ci < oc_4; ++ci) {
@@ -602,7 +605,8 @@ void conv_compute_2x2_3x3(const float* input,
                           hout,
                           wout,
                           false,
-                          zero_ptr);
+                          zero_ptr,
+                          nullptr);
     }
   }
 }
@@ -814,7 +818,8 @@ void conv_compute_2x2_3x3_small(const float* input,
                           hout,
                           wout,
                           false,
-                          zero_ptr);
+                          zero_ptr,
+                          nullptr);
     }
   } else {
     for (int ci = 0; ci < oc_4; ++ci) {
@@ -854,7 +859,8 @@ void conv_compute_2x2_3x3_small(const float* input,
                           hout,
                           wout,
                           false,
-                          zero_ptr);
+                          zero_ptr,
+                          nullptr);
     }
   }
 }
...
...
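A hedged reading of the winograd hunks above: the output write-back helper appears to have gained a trailing activation-parameter argument, and these call sites pass nullptr, presumably signalling no fused activation here. A sketch of that shape (helper name, struct fields, and loop are stand-ins, not the real API):

struct ActParam { int type; float six, alpha; };

void write_back(float* dst, const float* src, int n, const ActParam* act) {
  for (int i = 0; i < n; ++i) {
    float v = src[i];
    if (act != nullptr) {
      // apply relu / relu6 / leaky-relu according to act->type
      v = v > 0.f ? v : 0.f;  // placeholder: plain relu branch only
    }
    dst[i] = v;
  }
}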
lite/backends/arm/math/conv3x3s1_direct_fp32.cc @ 3455ab0a
...
...
@@ -76,6 +76,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
   const int threads = ctx->threads();
   int l2_size = ctx->llc_size() / sizeof(float);
   auto paddings = *param.paddings;
+  auto act_param = param.activation_param;
   const int pad_h = paddings[0];
   const int pad_w = paddings[2];
@@ -469,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
                     oh,
                     ow,
                     flag_relu,
-                    ptr_write);
+                    ptr_write,
+                    &act_param);
   }
   const float* weight_remain_ptr = weights + c_round_down * w_stride;
 #pragma omp parallel for num_threads(threads)
@@ -780,7 +782,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
                     oh,
                     ow,
                     flag_relu,
-                    ptr_write);
+                    ptr_write,
+                    &act_param);
     }
   }
 }
...
...
lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @ 3455ab0a
...
...
@@ -32,6 +32,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
                                  const int w_in,
                                  const int h_out,
                                  const int w_out,
+                                 const operators::ActivationParam act_param,
                                  ARMContext *ctx);
 void conv_depthwise_3x3s1p0_bias_s(float *dout,
@@ -46,6 +47,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
                                    const int w_in,
                                    const int h_out,
                                    const int w_out,
+                                   const operators::ActivationParam act_param,
                                    ARMContext *ctx);
 void conv_depthwise_3x3s1p1_bias(float *dout,
@@ -60,6 +62,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
                                  const int w_in,
                                  const int h_out,
                                  const int w_out,
+                                 const operators::ActivationParam act_param,
                                  ARMContext *ctx);
 void conv_depthwise_3x3s1p1_bias_s(float *dout,
@@ -74,6 +77,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
                                    const int w_in,
                                    const int h_out,
                                    const int w_out,
+                                   const operators::ActivationParam act_param,
                                    ARMContext *ctx);
 void conv_depthwise_3x3s1_fp32(const float *din,
@@ -90,6 +94,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                int pad,
                                bool flag_bias,
                                bool flag_relu,
+                               const operators::ActivationParam act_param,
                                ARMContext *ctx) {
   if (pad == 0) {
     if (w_in > 5) {
@@ -105,6 +110,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                   w_in,
                                   h_out,
                                   w_out,
+                                  act_param,
                                   ctx);
     } else {
       conv_depthwise_3x3s1p0_bias_s(dout,
@@ -119,6 +125,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                     w_in,
                                     h_out,
                                     w_out,
+                                    act_param,
                                     ctx);
     }
   }
@@ -136,6 +143,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                   w_in,
                                   h_out,
                                   w_out,
+                                  act_param,
                                   ctx);
     } else {
       conv_depthwise_3x3s1p1_bias_s(dout,
@@ -150,11 +158,12 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                     w_in,
                                     h_out,
                                     w_out,
+                                    act_param,
                                     ctx);
     }
   }
 }
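
A self-contained sketch of the plumbing pattern in the hunks above: the activation description now rides along with every depthwise kernel call, next to the existing flag_relu bool and context pointer. All types and kernel names below are stand-ins, not the real Paddle-Lite declarations.

struct ActivationParam { bool has_active = false; };
struct ARMContext {};

void kernel_p0(const ActivationParam& act, ARMContext* ctx) {}
void kernel_p1(const ActivationParam& act, ARMContext* ctx) {}

// Mirrors conv_depthwise_3x3s1_fp32: each pad branch forwards act before ctx.
void dispatch(int pad, const ActivationParam& act, ARMContext* ctx) {
  if (pad == 0) {
    kernel_p0(act, ctx);
  } else {
    kernel_p1(act, ctx);
  }
}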
// clang-format on
#ifdef __aarch64__
#define INIT_S1 \
"PRFM PLDL1KEEP, [%[din_ptr0]] \n" \
...
...
@@ -255,14 +264,12 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "fmla v14.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1] */ \
   "fmla v13.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1] */ \
   \
-  "ext v16.16b, %[vzero].16b, v8.16b, #12\n" /* v16 = 00123 */ \
-  "ext v17.16b, v8.16b, v9.16b, #4\n" /* v16 = 1234 */ \
-  /* r4 */ \
-  "fmla v15.4s, v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1] */ \
-  "fmla v14.4s, v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1] */ \
+  "ext v16.16b, %[vzero].16b, v8.16b, #12\n" /* v16 = 00123 */ \
+  "ext v17.16b, v8.16b, v9.16b, #4\n" /* v16 = 1234 */
+#define LEFT_RESULT_S1 \
+  /* r4 */ \
+  "fmla v15.4s, v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1] */ \
+  "fmla v14.4s, v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1] */ \
   \
   "st1 {v12.4s}, [%[doutr0]], #16\n" /* vst1q_f32() */ \
   "st1 {v13.4s}, [%[doutr1]], #16\n" /* vst1q_f32() */ \
   "ld1 {v8.4s}, [%[din_ptr4]], #16\n" /* vld1q_f32(din_ptr0) */ \
...
...
@@ -345,16 +352,15 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "fmla v13.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
   "fmla v12.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
   \
-  "ext v16.16b, v6.16b, v7.16b, #4\n" /* v16 = 1234 */ \
-  "ext v17.16b, v6.16b, v7.16b, #8\n" /* v16 = 2345 */ \
-  /* r3 */ \
-  "fmla v15.4s, v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmla v14.4s, v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmla v13.4s, v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  \
-  "ld1 {v6.4s}, [%[din_ptr3]], #16\n" /* vld1q_f32(din_ptr0) */ \
+  "ext v16.16b, v6.16b, v7.16b, #4\n" /* v16 = 1234 */ \
+  "ext v17.16b, v6.16b, v7.16b, #8\n" /* v16 = 2345 */
+#define MID_RESULT_S1 \
+  /* r3 */ \
+  "fmla v15.4s, v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v14.4s, v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v13.4s, v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  \
+  "ld1 {v6.4s}, [%[din_ptr3]], #16\n" /* vld1q_f32(din_ptr0) */ \
   \
   "st1 {v12.4s}, [%[doutr0]], #16\n" \
   \
   "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
@@ -411,30 +417,31 @@ void conv_depthwise_3x3s1_fp32(const float *din,
 #define RIGHT_COMPUTE_S1 \
   "3:                               \n" \
+  "movi v20.4s, #0                  \n" \
   "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \
   "ld1 {v22.4s}, [%[doutr0]]        \n" \
   "ld1 {v23.4s}, [%[doutr1]]        \n" \
   "ld1 {v24.4s}, [%[doutr2]]        \n" \
   "ld1 {v25.4s}, [%[doutr3]]        \n" \
   \
-  "bif v0.16b, %[vzero].16b, v18.16b \n" \
-  "bif v1.16b, %[vzero].16b, v19.16b \n" \
-  "bif v2.16b, %[vzero].16b, v18.16b \n" \
-  "bif v3.16b, %[vzero].16b, v19.16b \n" \
+  "bif v0.16b, v20.16b, v18.16b \n" \
+  "bif v1.16b, v20.16b, v19.16b \n" \
+  "bif v2.16b, v20.16b, v18.16b \n" \
+  "bif v3.16b, v20.16b, v19.16b \n" \
   \
-  "bif v4.16b, %[vzero].16b, v18.16b \n" \
-  "bif v5.16b, %[vzero].16b, v19.16b \n" \
-  "bif v6.16b, %[vzero].16b, v18.16b \n" \
-  "bif v7.16b, %[vzero].16b, v19.16b \n" \
+  "bif v4.16b, v20.16b, v18.16b \n" \
+  "bif v5.16b, v20.16b, v19.16b \n" \
+  "bif v6.16b, v20.16b, v18.16b \n" \
+  "bif v7.16b, v20.16b, v19.16b \n" \
   \
   "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ \
   "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
   /* r0 */ \
   "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   \
-  "bif v8.16b,  %[vzero].16b, v18.16b \n" \
-  "bif v9.16b,  %[vzero].16b, v19.16b \n" \
-  "bif v10.16b, %[vzero].16b, v18.16b \n" \
-  "bif v11.16b, %[vzero].16b, v19.16b \n" \
+  "bif v8.16b,  v20.16b, v18.16b \n" \
+  "bif v9.16b,  v20.16b, v19.16b \n" \
+  "bif v10.16b, v20.16b, v18.16b \n" \
+  "bif v11.16b, v20.16b, v19.16b \n" \
   \
   "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
@@ -467,15 +474,13 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "fmla v13.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
   "fmla v12.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
   \
-  "ext v16.16b, v6.16b, v7.16b, #4\n" /* v16 = 1234 */ \
-  "ext v17.16b, v6.16b, v7.16b, #8\n" /* v16 = 2345 */ \
-  /* r3 */ \
-  "fmla v15.4s, v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmla v14.4s, v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmla v13.4s, v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "ext v16.16b, v6.16b, v7.16b, #4\n" /* v16 = 1234 */ \
+  "ext v17.16b, v6.16b, v7.16b, #8\n" /* v16 = 2345 */
+#define RIGHT_RESULT_S1 \
+  /* r3 */ \
+  "fmla v15.4s, v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v14.4s, v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v13.4s, v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   \
   "bif v12.16b, v22.16b, v18.16b \n" \
   \
   "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
@@ -520,10 +525,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "st1 {v15.4s}, [%[doutr3]], #16 \n"
 #define LEFT_RESULT_S1_RELU \
-  /* r4 */ \
-  "fmla v15.4s, v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1] */ \
-  "fmla v14.4s, v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1] */ \
-  \
   "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ \
   "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ \
...
...
@@ -570,14 +571,113 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "ld1 {v15.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
   "blt 3f \n"
+#define LEFT_RESULT_S1_RELU6 \
+  "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ \
+  \
+  "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1] */ \
+  \
+  "fmin v12.4s, v12.4s, %[vsix].4s \n" /* relu6 */ \
+  "fmin v13.4s, v13.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "ld1 {v9.4s}, [%[din_ptr4]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1] */ \
+  "fmla v14.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1] */ \
+  \
+  "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
+  "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
+  \
+  "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123 */ \
+  "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
+  \
+  "fmla v15.4s, v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1] */ \
+  \
+  "ld1 {v12.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "ld1 {v13.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  /* r5 */ \
+  "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ \
+  \
+  "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1] */ \
+  \
+  "fmin v14.4s, v14.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "ld1 {v11.4s}, [%[din_ptr5]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1] */ \
+  \
+  "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \
+  \
+  "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ \
+  "ld1 {v14.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  \
+  "fmin v15.4s, v15.4s, %[vsix].4s \n" /* relu6 */ \
+  "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
+  "cmp %w[cnt], #1 \n" \
+  "ld1 {v15.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "blt 3f \n"
+#define LEFT_RESULT_S1_LEAKY_RELU \
+  "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
+  "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
+  "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
+  "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1] */ \
+  \
+  "bif v12.16b, v20.16b, v18.16b \n" /* choose */ \
+  "bif v13.16b, v21.16b, v19.16b \n" /* choose */ \
+  "ld1 {v9.4s}, [%[din_ptr4]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1] */ \
+  "fmla v14.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1] */ \
+  \
+  "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123 */ \
+  "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
+  "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
+  "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
+  \
+  "fmla v15.4s, v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1] */ \
+  \
+  "ld1 {v12.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "ld1 {v13.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  /* r5 */ \
+  "cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1] */ \
+  \
+  "bif v14.16b, v20.16b, v18.16b \n" /* choose */ \
+  \
+  "ld1 {v11.4s}, [%[din_ptr5]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1] */ \
+  \
+  "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \
+  \
+  "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "ld1 {v14.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "bif v15.16b, v20.16b, v18.16b \n" /* choose */ \
+  "cmp %w[cnt], #1 \n" \
+  "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
+  "ld1 {v15.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "blt 3f \n"
 #define MID_RESULT_S1_RELU \
   /* r3 */ \
   "fmla v15.4s, v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   "fmla v14.4s, v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   "fmla v13.4s, v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   \
   "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /* vld1q_f32(din_ptr0) */ \
-  "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ \
+  "movi v20.4s, #0 \n" \
+  "fmax v12.4s, v12.4s, v20.4s \n" /* relu */ \
   \
   "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
   "fmla v14.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
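For readability (not from the patch): intrinsic equivalents of the fmax/fmin and cmhs/fmul/bif instruction sequences used by the new *_RELU6 and *_LEAKY_RELU macros, applied to one float32x4_t accumulator.

#include <arm_neon.h>

static inline float32x4_t relu6_q(float32x4_t v, float32x4_t vsix) {
  v = vmaxq_f32(v, vdupq_n_f32(0.f));  // "fmax ... vzero" -> relu
  return vminq_f32(v, vsix);           // "fmin ... vsix"  -> clip at six
}

static inline float32x4_t leaky_relu_q(float32x4_t v, float32x4_t vscale) {
  uint32x4_t ge = vcgeq_f32(v, vdupq_n_f32(0.f));  // "cmhs": v >= 0 mask
  float32x4_t scaled = vmulq_f32(v, vscale);       // "fmul": v * alpha
  return vbslq_f32(ge, v, scaled);                 // "bif":  select by mask
}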
@@ -598,7 +698,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "fmla v14.4s, v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   \
   "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /* vld1q_f32(din_ptr0) */ \
-  "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v13.4s, v13.4s, v20.4s \n" /* relu */ \
   \
   "fmla v15.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
   "fmla v14.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
@@ -617,7 +717,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   /* r3 */ \
   "fmla v15.4s, v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /* vld1q_f32(din_ptr0) */ \
-  "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v14.4s, v14.4s, v20.4s \n" /* relu */ \
   \
   "fmla v15.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
@@ -633,20 +733,157 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "subs %w[cnt], %w[cnt], #1 \n" \
   \
-  "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v15.4s, v15.4s, v20.4s \n" /* relu */ \
   \
   "st1 {v15.4s}, [%[doutr3]], #16 \n" \
   "ld1 {v15.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
   \
   "bne 1b \n"
+#define MID_RESULT_S1_RELU6 \
+  "movi v20.4s, #0 \n" \
+  "fmax v12.4s, v12.4s, v20.4s \n" /* relu */ \
+  \
+  "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v13.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "fmin v12.4s, v12.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "ld1 {v7.4s}, [%[din_ptr3]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v13.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "st1 {v12.4s}, [%[doutr0]], #16 \n" \
+  "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ \
+  /* r3 */ \
+  "fmla v15.4s, v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v14.4s, v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  \
+  "ld1 {v12.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  "fmax v13.4s, v13.4s, v20.4s \n" /* relu */ \
+  \
+  "fmla v15.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "fmin v13.4s, v13.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "ld1 {v9.4s}, [%[din_ptr4]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "st1 {v13.4s}, [%[doutr1]], #16 \n" \
+  /* r3 */ \
+  "fmla v15.4s, v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  "ld1 {v13.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "fmax v14.4s, v14.4s, v20.4s \n" /* relu */ \
+  \
+  "fmla v15.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "fmin v14.4s, v14.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "ld1 {v11.4s}, [%[din_ptr5]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "st1 {v14.4s}, [%[doutr2]], #16 \n" \
+  \
+  "fmax v15.4s, v15.4s, v20.4s \n" /* relu */ \
+  "ld1 {v14.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  \
+  "fmin v15.4s, v15.4s, %[vsix].4s \n" /* relu6 */ \
+  "subs %w[cnt], %w[cnt], #1 \n" \
+  \
+  "st1 {v15.4s}, [%[doutr3]], #16 \n" \
+  "ld1 {v15.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  \
+  "bne 1b \n"
+#define MID_RESULT_S1_LEAKY_RELU \
+  "movi v21.4s, #0 \n" \
+  "cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v13.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "bif v12.16b, v20.16b, v18.16b \n" /* choose */ \
+  \
+  "ld1 {v7.4s}, [%[din_ptr3]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v13.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ \
+  /* r3 */ \
+  "st1 {v12.4s}, [%[doutr0]], #16 \n" \
+  "fmla v15.4s, v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v14.4s, v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  \
+  "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  "cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "fmla v15.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "ld1 {v12.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "bif v13.16b, v20.16b, v18.16b \n" /* choose */ \
+  \
+  "ld1 {v9.4s}, [%[din_ptr4]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "st1 {v13.4s}, [%[doutr1]], #16 \n" \
+  /* r3 */ \
+  "fmla v15.4s, v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /* vld1q_f32(din_ptr0) */ \
+  "ld1 {v13.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "fmla v15.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "bif v14.16b, v20.16b, v18.16b \n" /* choose */ \
+  \
+  "ld1 {v11.4s}, [%[din_ptr5]] \n" /* vld1q_f32(din_ptr0) */ \
+  \
+  "fmla v15.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "st1 {v14.4s}, [%[doutr2]], #16 \n" \
+  \
+  "cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "ld1 {v14.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  "bif v15.16b, v20.16b, v18.16b \n" /* choose */ \
+  "subs %w[cnt], %w[cnt], #1 \n" \
+  \
+  "st1 {v15.4s}, [%[doutr3]], #16 \n" \
+  "ld1 {v15.4s}, [%[bias_val]] \n" /* vdupq_n_f32(bias_val) */ \
+  \
+  "bne 1b \n"
 #define RIGHT_RESULT_S1_RELU \
-  /* r3 */ \
-  "fmla v15.4s, v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmla v14.4s, v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmla v13.4s, v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
-  "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v12.4s, v12.4s, v20.4s \n" /* relu */ \
   \
   "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
   "fmla v14.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
@@ -664,7 +901,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "fmla v14.4s, v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   \
   "st1 {v12.4s}, [%[doutr0]], #16 \n" \
-  "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v13.4s, v13.4s, v20.4s \n" /* relu */ \
   \
   "fmla v15.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
   "fmla v14.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
@@ -680,7 +917,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "st1 {v13.4s}, [%[doutr1]], #16 \n" \
   /* r3 */ \
   "fmla v15.4s, v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
   \
-  "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v14.4s, v14.4s, v20.4s \n" /* relu */ \
   \
   "fmla v15.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
...
...
@@ -690,72 +927,184 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   \
   "st1 {v14.4s}, [%[doutr2]], #16 \n" \
   \
-  "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ \
+  "fmax v15.4s, v15.4s, v20.4s \n" /* relu */ \
   \
   "bif v15.16b, v25.16b, v18.16b \n" \
   \
   "st1 {v15.4s}, [%[doutr3]], #16 \n"
+#define RIGHT_RESULT_S1_RELU6 \
+  "fmax v12.4s, v12.4s, v20.4s \n" /* relu */ \
+  \
+  "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v13.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "fmin v12.4s, v12.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "fmla v15.4s, v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v13.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ \
+  /* r3 */ \
+  "bif v12.16b, v22.16b, v18.16b \n" \
+  "fmla v15.4s, v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v14.4s, v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmax v13.4s, v13.4s, v20.4s \n" /* relu */ \
+  \
+  "fmla v15.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "st1 {v12.4s}, [%[doutr0]], #16 \n" \
+  \
+  "fmin v13.4s, v13.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "fmla v15.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "bif v13.16b, v23.16b, v18.16b \n" \
+  "fmla v15.4s, v10.4s, v20.s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  \
+  "fmax v14.4s, v14.4s, v20.4s \n" /* relu */ \
+  "st1 {v13.4s}, [%[doutr1]], #16 \n" \
+  /* r3 */ \
+  "fmla v15.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "fmin v14.4s, v14.4s, %[vsix].4s \n" /* relu6 */ \
+  \
+  "fmla v15.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "bif v14.16b, v24.16b, v18.16b \n" \
+  "fmax v15.4s, v15.4s, v20.4s \n" /* relu */ \
+  \
+  "st1 {v14.4s}, [%[doutr2]], #16 \n" \
+  \
+  "fmin v15.4s, v15.4s, %[vsix].4s \n" /* relu6 */ \
+  "bif v15.16b, v25.16b, v18.16b \n" \
+  \
+  "st1 {v15.4s}, [%[doutr3]], #16 \n"
+#define RIGHT_RESULT_S1_LEAKY_RELU \
+  "movi v1.4s, #0 \n" \
+  "cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \
+  "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "fmla v15.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v13.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "bif v12.16b, v21.16b, v20.16b \n" /* choose */ \
+  \
+  "fmla v15.4s, v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v13.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ \
+  /* r3 */ \
+  "bif v12.16b, v22.16b, v18.16b \n" \
+  "fmla v15.4s, v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  "fmla v14.4s, v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  \
+  "cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \
+  "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
+  "st1 {v12.4s}, [%[doutr0]], #16 \n" \
+  \
+  "fmla v15.4s, v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  "fmla v14.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "bif v13.16b, v21.16b, v20.16b \n" \
+  "fmla v15.4s, v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  "fmla v14.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
+  "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
+  \
+  "bif v13.16b, v23.16b, v18.16b \n" \
+  \
+  "fmla v15.4s, v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0] */ \
+  \
+  "cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \
+  "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \
+  "st1 {v13.4s}, [%[doutr1]], #16 \n" \
+  /* r3 */ \
+  "fmla v15.4s, v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1] */ \
+  \
+  "bif v14.16b, v21.16b, v20.16b \n" \
+  "fmla v15.4s, v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2] */ \
+  \
+  "bif v14.16b, v24.16b, v18.16b \n" \
+  \
+  "cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \
+  "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "st1 {v14.4s}, [%[doutr2]], #16 \n" \
+  "bif v15.16b, v21.16b, v20.16b \n" \
+  "bif v15.16b, v25.16b, v18.16b \n" \
+  "st1 {v15.4s}, [%[doutr3]], #16 \n"
 #define COMPUTE_S_S1 \
   "prfm pldl1keep, [%[din0]]\n" \
   "prfm pldl1keep, [%[din1]]\n" \
   "prfm pldl1keep, [%[din2]]\n" \
   "prfm pldl1keep, [%[din3]]\n" \
   \
   "ld1 {v0.4s}, [%[din0]], #16\n" \
   "ld1 {v1.4s}, [%[din1]], #16\n" \
   "ld1 {v2.4s}, [%[din2]], #16\n" \
   "ld1 {v3.4s}, [%[din3]], #16\n" \
   \
-  "bif v0.16b, %[zero].16b, %[mask].16b\n" \
-  "bif v1.16b, %[zero].16b, %[mask].16b\n" \
-  "bif v2.16b, %[zero].16b, %[mask].16b\n" \
-  "bif v3.16b, %[zero].16b, %[mask].16b\n" \
+  "bif v0.16b, %[vzero].16b, %[mask].16b\n" \
+  "bif v1.16b, %[vzero].16b, %[mask].16b\n" \
+  "bif v2.16b, %[vzero].16b, %[mask].16b\n" \
+  "bif v3.16b, %[vzero].16b, %[mask].16b\n" \
   \
-  "ext v4.16b, %[zero].16b, v0.16b, #12\n" \
-  "ext v5.16b, %[zero].16b, v1.16b, #12\n" \
-  "ext v6.16b, %[zero].16b, v2.16b, #12\n" \
-  "ext v7.16b, %[zero].16b, v3.16b, #12\n" \
+  "ext v4.16b, %[vzero].16b, v0.16b, #12\n" \
+  "ext v5.16b, %[vzero].16b, v1.16b, #12\n" \
+  "ext v6.16b, %[vzero].16b, v2.16b, #12\n" \
+  "ext v7.16b, %[vzero].16b, v3.16b, #12\n" \
   \
-  "ext v8.16b, v0.16b, %[zero].16b, #4\n" \
-  "ext v9.16b, v1.16b, %[zero].16b, #4\n" \
-  "ext v10.16b, v2.16b, %[zero].16b, #4\n" \
-  "ext v11.16b, v3.16b, %[zero].16b, #4\n" \
+  "ext v8.16b, v0.16b, %[vzero].16b, #4\n" \
+  "ext v9.16b, v1.16b, %[vzero].16b, #4\n" \
+  "ext v10.16b, v2.16b, %[vzero].16b, #4\n" \
+  "ext v11.16b, v3.16b, %[vzero].16b, #4\n" \
   \
   "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \
   "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \
   \
   "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \
   "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \
   \
   "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \
   "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \
   \
   "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \
   "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \
   \
   "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \
   "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \
   \
   "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \
   "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \
   \
   "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \
   "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \
   \
   "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \
   "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \
   \
   "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \
   "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \
   \
   "fadd v12.4s, v12.4s, v14.4s\n" \
   "fadd v12.4s, v12.4s, v16.4s\n" \
   \
   "fadd v13.4s, v13.4s, v15.4s\n" \
   "fadd v13.4s, v13.4s, v17.4s\n" \
   \
   "fadd v12.4s, v12.4s, %[bias].4s\n" \
   "fadd v13.4s, v13.4s, %[bias].4s\n"
#define RESULT_S_S1 \
...
...
@@ -765,16 +1114,42 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "st1 {v12.4s}, [%[out1]]\n" \
   "st1 {v13.4s}, [%[out2]]\n"
 #define RESULT_S_S1_RELU \
   "prfm pldl1keep, [%[out1]]\n" \
   "prfm pldl1keep, [%[out2]]\n" \
   \
-  "fmax v12.4s, v12.4s, %[zero].4s\n" \
-  "fmax v13.4s, v13.4s, %[zero].4s\n" \
+  "fmax v12.4s, v12.4s, %[vzero].4s\n" \
+  "fmax v13.4s, v13.4s, %[vzero].4s\n" \
   \
   "st1 {v12.4s}, [%[out1]]\n" \
   "st1 {v13.4s}, [%[out2]]\n"
+#define RESULT_S_S1_RELU6 \
+  "prfm pldl1keep, [%[out1]]\n" \
+  "prfm pldl1keep, [%[out2]]\n" \
+  \
+  "fmax v12.4s, v12.4s, %[vzero].4s\n" \
+  "fmax v13.4s, v13.4s, %[vzero].4s\n" \
+  \
+  "fmin v12.4s, v12.4s, %[vsix].4s\n" \
+  "fmin v13.4s, v13.4s, %[vsix].4s\n" \
+  \
+  "st1 {v12.4s}, [%[out1]]\n" \
+  "st1 {v13.4s}, [%[out2]]\n"
+#define RESULT_S_S1_LEAKY_RELU \
+  "prfm pldl1keep, [%[out1]]\n" \
+  "prfm pldl1keep, [%[out2]]\n" \
+  \
+  "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
+  "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
+  "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
+  "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
+  \
+  "bif v12.16b, v20.16b, v18.16b \n" \
+  "bif v13.16b, v21.16b, v19.16b \n" \
+  "st1 {v12.4s}, [%[out1]]\n" \
+  "st1 {v13.4s}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"prfm pldl1keep, [%[din0]]\n" \
"prfm pldl1keep, [%[din1]]\n" \
...
...
@@ -786,17 +1161,17 @@ void conv_depthwise_3x3s1_fp32(const float *din,
   "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \
   "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \
   \
-  "bif v0.16b, %[zero].16b, %[mask1].16b\n" \
-  "bif v1.16b, %[zero].16b, %[mask2].16b\n" \
-  "bif v2.16b, %[zero].16b, %[mask1].16b\n" \
-  "bif v3.16b, %[zero].16b, %[mask2].16b\n" \
-  "bif v4.16b, %[zero].16b, %[mask1].16b\n" \
-  "bif v5.16b, %[zero].16b, %[mask2].16b\n" \
-  "bif v6.16b, %[zero].16b, %[mask1].16b\n" \
-  "bif v7.16b, %[zero].16b, %[mask2].16b\n" \
+  "bif v0.16b, %[vzero].16b, %[mask1].16b\n" \
+  "bif v1.16b, %[vzero].16b, %[mask2].16b\n" \
+  "bif v2.16b, %[vzero].16b, %[mask1].16b\n" \
+  "bif v3.16b, %[vzero].16b, %[mask2].16b\n" \
+  "bif v4.16b, %[vzero].16b, %[mask1].16b\n" \
+  "bif v5.16b, %[vzero].16b, %[mask2].16b\n" \
+  "bif v6.16b, %[vzero].16b, %[mask1].16b\n" \
+  "bif v7.16b, %[vzero].16b, %[mask2].16b\n" \
   \
   "ext v8.16b, v0.16b, v1.16b, #4\n" \
   "ext v9.16b, v0.16b, v1.16b, #8\n" \
...
...
@@ -849,7 +1224,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
// "st1 {v12.4s}, [%[out1]]\n" \
// "st1 {v13.4s}, [%[out2]]\n" \
#else
#define INIT_S1 \
"pld [%[din0_ptr]] @ preload data\n" \
...
...
@@ -1129,6 +1503,66 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define LEFT_RESULT_S1_RELU6 \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.f32 {d28-d29}, [%[six_ptr]] @ load six \n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
\
"vmin.f32 q4, q4, q14 @ relu6 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
"vdup.32 q4, %[bias_val] @ and \n" \
"vmin.f32 q5, q5, q14 @ relu6 \n" \
"cmp %[cnt], #1 @ check whether has mid cols\n" \
\
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define LEFT_RESULT_S1_LEAKY_RELU \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
"vld1.f32 {d28-d29}, [%[scale_ptr]] @ load scale \n" \
\
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
"vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vbif q4, q6, q15 @ choose \n" \
"vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q5, q14 \n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vbif q5, q6, q7 @ choose \n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
"cmp %[cnt], #1 @ check whether has mid cols\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define MID_RESULT_S1_RELU \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
...
...
@@ -1157,6 +1591,69 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"bne 1b @ jump to main loop start point\n"
#define MID_RESULT_S1_RELU6 \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[six_ptr]]! @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmin.f32 q4, q4, q14 @ relu6 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vmin.f32 q5, q5, q14 @ relu6 \n" \
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"subs %[cnt], #1 @ loop count minus 1\n" \
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
\
"bne 1b @ jump to main loop start point\n"
#define MID_RESULT_S1_LEAKY_RELU \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
\
"vbif q4, q6, q15 @ choose \n" \
"vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
\
"vbif q5, q6, q7 @ choose \n" \
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"subs %[cnt], #1 @ loop count minus 1\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vdup.32 q5, %[bias_val] @ and \n" \
\
"bne 1b @ jump to main loop start point\n"
#define RIGHT_RESULT_S1_RELU \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
...
...
@@ -1178,6 +1675,58 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define RIGHT_RESULT_S1_RELU6 \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmin.f32 q4, q4, q14 @ relu6 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vbif d8, d16, d19 @ bit select, deal with right pad\n" \
"vbif d9, d17, d23 @ bit select, deal with right pad\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmin.f32 q5, q5, q14 @ relu6 \n" \
"vbif d10, d20, d19 @ bit select, deal with right pad\n" \
"vbif d11, d21, d23 @ bit select, deal with right pad\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define RIGHT_RESULT_S1_LEAKY_RELU \
  /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vbif q4, q6, q15 @ choose \n" \
\
"vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q5, q14 \n" \
\
"vbif d8, d16, d19 @ bit select, deal with right pad\n" \
"vbif d9, d17, d23 @ bit select, deal with right pad\n" \
"vbif q5, q6, q7 @ choose \n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vbif d10, d20, d19 @ bit select, deal with right pad\n" \
"vbif d11, d21, d23 @ bit select, deal with right pad\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define COMPUTE_S_S1 \
"pld [%[din0]]\n" \
"pld [%[din1]]\n" \
...
...
@@ -1251,6 +1800,36 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define RESULT_S_S1_RELU6 \
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vld1.32 {d20-d21}, [%[six_ptr]] \n" \
"vmax.f32 q14, q14, %q[vzero]\n" \
"vmax.f32 q15, q15, %q[vzero]\n" \
\
"vmin.f32 q14, q14, q10 \n" \
"vmin.f32 q15, q15, q10 \n" \
\
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define RESULT_S_S1_LEAKY_RELU \
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vld1.32 {d18-d19}, [%[scale_ptr]] \n" \
"vcge.f32 q10, q14, %q[vzero] @ q0 > 0 \n" \
"vcge.f32 q11, q15, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q12, q14, q9 \n" \
"vmul.f32 q13, q15, q9 \n" \
\
"vbif q14, q10, q12 \n" \
"vbif q15, q11, q13 \n" \
\
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"pld [%[din0]]\n" \
"pld [%[din1]]\n" \
...
...
@@ -1333,6 +1912,413 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"vadd.f32 q15, q5, q9 @ q4 += q10 \n"
#endif
#ifdef __aarch64__
void act_switch_3x3s1p1(const float* din_ptr0,
                        const float* din_ptr1,
                        const float* din_ptr2,
                        const float* din_ptr3,
                        const float* din_ptr4,
                        const float* din_ptr5,
                        float* doutr0,
                        float* doutr1,
                        float* doutr2,
                        float* doutr3,
                        float32x4_t wr0,
                        float32x4_t wr1,
                        float32x4_t wr2,
                        unsigned int* vmask,
                        unsigned int* rmask,
                        float32x4_t vzero,
                        float* vbias,
                        int cnt,
                        const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
    float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
    float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
              [vzero] "w"(vzero)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
                MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2), [vsix] "w"(vsix),
              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
              [vzero] "w"(vzero)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /* din = din >= 0 ? din : din * scale */
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU MID_COMPUTE_S1
                MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
                    RIGHT_RESULT_S1_LEAKY_RELU
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [vscale] "w"(vscale), [bias_val] "r"(vbias), [vmask] "r"(vmask),
              [rmask] "r"(rmask), [vzero] "w"(vzero)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
    asm volatile(
        INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 MID_RESULT_S1
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
        : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
          [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
          [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
          [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
          [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2), [doutr3] "+r"(doutr3)
        : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2), [bias_val] "r"(vbias),
          [vmask] "r"(vmask), [rmask] "r"(rmask), [vzero] "w"(vzero)
        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
          "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
          "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
  }
}
#else
void act_switch_3x3s1p1(const float* din_ptr0,
                        const float* din_ptr1,
                        const float* din_ptr2,
                        const float* din_ptr3,
                        float* doutr0,
                        float* doutr1,
                        float32x4_t wr0,
                        float32x4_t wr1,
                        float32x4_t wr2,
                        unsigned int* vmask_ptr,
                        unsigned int* rmask_ptr,
                        float32x4_t vzero,
                        float bias_val,
                        int cnt,
                        const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
    float tmp = act_param.Relu_clipped_coef;
    float ss = act_param.Leaky_relu_alpha;
    float vsix[4] = {tmp, tmp, tmp, tmp};
    float vscale[4] = {ss, ss, ss, ss};
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias_val] "r"(bias_val), [vzero] "w"(vzero)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
                MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias_val] "r"(bias_val), [six_ptr] "r"(vsix),
              [vzero] "w"(vzero)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /* din = din >= 0 ? din : din * scale */
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU MID_COMPUTE_S1
                MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
                    RIGHT_RESULT_S1_LEAKY_RELU
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias_val] "r"(bias_val), [scale_ptr] "r"(vscale),
              [vzero] "w"(vzero)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
    asm volatile(
        INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 MID_RESULT_S1
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
        : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
          [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
          [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
          [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr), [vmask] "+r"(vmask_ptr)
        : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
          [bias_val] "r"(bias_val), [vzero] "w"(vzero)
        : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
          "q12", "q13", "q14", "q15");
  }
}
#endif
// clang-format on
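
An illustration (stand-in stubs, not the patch's code) of the act_switch_* shape reconstructed above: the activation type is resolved once per block of output rows, so each inline-asm hot loop is assembled from the matching macro set and stays branch-free inside.

#include <cstdio>

enum class ActivationType { kRelu, kRelu6, kLeakyRelu };

void act_switch_sketch(ActivationType t) {
  switch (t) {
    case ActivationType::kRelu:
      /* asm volatile(... *_RELU macros ...) */
      break;
    case ActivationType::kRelu6:
      /* asm volatile(... *_RELU6 macros, six operand ...) */
      break;
    case ActivationType::kLeakyRelu:
      /* asm volatile(... *_LEAKY_RELU macros, scale operand ...) */
      break;
    default:
      std::fprintf(stderr, "act fuse not supported\n");
  }
}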
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
...
...
@@ -1349,6 +2335,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
                                  const int w_in,
                                  const int h_out,
                                  const int w_out,
+                                 const operators::ActivationParam act_param,
                                  ARMContext *ctx) {
   //! pad is done implicit
   const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
...
...
@@ -1486,106 +2473,25 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
       }
       int cnt = cnt_col;
-      if (flag_relu) {
-        asm volatile(
-            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
-                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
-            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
-              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
-              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
-              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
-              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
-              [doutr3] "+r"(doutr3)
-            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
-              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
-              [vzero] "w"(vzero)
-            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
-              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
-              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
-      } else {
-        asm volatile(
-            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
-                MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
-            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
-              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
-              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
-              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
-              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
-              [doutr3] "+r"(doutr3)
-            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
-              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
-              [vzero] "w"(vzero)
-            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
-              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
-              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
-      }
+      act_switch_3x3s1p1(din_ptr0,
+                         din_ptr1,
+                         din_ptr2,
+                         din_ptr3,
+                         din_ptr4,
+                         din_ptr5,
+                         doutr0,
+                         doutr1,
+                         doutr2,
+                         doutr3,
+                         wr0,
+                         wr1,
+                         wr2,
+                         vmask,
+                         rmask,
+                         vzero,
+                         vbias,
+                         cnt,
+                         act_param);
       dout_ptr = dout_ptr + 4 * w_out;
     }
#else
...
...
@@ -1598,7 +2504,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
       doutr0 = dout_ptr;
       doutr1 = dout_ptr + w_out;
-      // unsigned int* rst_mask = rmask;
       if (i == 0) {
         din_ptr0 = zero_ptr;
...
...
@@ -1635,77 +2540,314 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
       int cnt = cnt_col;
       unsigned int *rmask_ptr = rmask;
       unsigned int *vmask_ptr = vmask;
-      if (flag_relu) {
-        asm volatile(
-            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
-                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
-            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
-              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
-              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
-              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
-              [vmask] "+r"(vmask_ptr)
-            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
-              [bias_val] "r"(bias_val), [vzero] "w"(vzero)
-            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
-              "q11", "q12", "q13", "q14", "q15");
-      } else {
-        asm volatile(
-            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
-                MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
-            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
-              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
-              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
-              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
-              [vmask] "+r"(vmask_ptr)
-            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
-              [bias_val] "r"(bias_val), [vzero] "w"(vzero)
-            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
-              "q11", "q12", "q13", "q14", "q15");
-      }
+      act_switch_3x3s1p1(din_ptr0,
+                         din_ptr1,
+                         din_ptr2,
+                         din_ptr3,
+                         doutr0,
+                         doutr1,
+                         wr0,
+                         wr1,
+                         wr2,
+                         vmask_ptr,
+                         rmask_ptr,
+                         vzero,
+                         bias_val,
+                         cnt,
+                         act_param);
       dout_ptr += 2 * w_out;
     }
     //! end of processing mid rows
 #endif
     }
   }
 }
void act_switch_3x3s1p1_s(const float* din_ptr0,
                          const float* din_ptr1,
                          const float* din_ptr2,
                          const float* din_ptr3,
                          float* doutr0,
                          float* doutr1,
                          float32x4_t wr0,
                          float32x4_t wr1,
                          float32x4_t wr2,
                          uint32x4_t vmask_rp,
                          float32x4_t vzero,
                          float32x4_t wbias,
                          const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
#ifdef __aarch64__
    float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
    float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
    float tmp = act_param.Relu_clipped_coef;
    float ss = act_param.Leaky_relu_alpha;
    float vsix[4] = {tmp, tmp, tmp, tmp};
    float vscale[4] = {ss, ss, ss, ss};
#endif
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                       "v17");
        break;
#else
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
                     : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11",
                       "q12", "q13", "q14", "q15");
        break;
#endif
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [vsix] "w"(vsix),
                       [out1] "r"(doutr0), [out2] "r"(doutr1)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                       "v17");
        break;
#else
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [six_ptr] "r"(vsix),
                       [out1] "r"(doutr0), [out2] "r"(doutr1)
                     : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11",
                       "q12", "q13", "q14", "q15");
        break;
#endif
      case lite_api::ActivationType::kLeakyRelu:
        /* din = din >= 0 ? din : din * scale */
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [vscale] "w"(vscale),
                       [out1] "r"(doutr0), [out2] "r"(doutr1)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                       "v17", "v18", "v19", "v20");
        break;
#else
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [scale_ptr] "r"(vscale),
                       [out1] "r"(doutr0), [out2] "r"(doutr1)
                     : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11",
                       "q12", "q13", "q14", "q15");
        break;
#endif
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
#ifdef __aarch64__
    asm volatile(COMPUTE_S_S1 RESULT_S_S1
                 : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                   [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                 : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                   [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                   [bias] "w"(wbias), [out1] "r"(doutr0), [out2] "r"(doutr1)
                 : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
                   "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17");
#else
    asm volatile(COMPUTE_S_S1 RESULT_S_S1
                 : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                   [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                 : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                   [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                   [bias] "w"(wbias), [out1] "r"(doutr0), [out2] "r"(doutr1)
                 : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11",
                   "q12", "q13", "q14", "q15");
#endif
  }
}
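The RESULT_S_S1_RELU6 and RESULT_S_S1_LEAKY_RELU macros selected above apply the activation vector-wide before the store. A minimal NEON-intrinsics rendering of those two epilogues (a sketch of the semantics, not the macros themselves; helper names are illustrative):

#include <arm_neon.h>

// vout is one 4-lane accumulator row; vzero/vsix/vscale are the broadcasts
// prepared by the dispatch helper above.
static inline float32x4_t epilogue_relu6(float32x4_t vout,
                                         float32x4_t vzero,
                                         float32x4_t vsix) {
  return vminq_f32(vmaxq_f32(vout, vzero), vsix);  // clamp to [0, six]
}

static inline float32x4_t epilogue_leaky_relu(float32x4_t vout,
                                              float32x4_t vzero,
                                              float32x4_t vscale) {
  uint32x4_t pos = vcgeq_f32(vout, vzero);    // lanes with vout >= 0
  float32x4_t neg = vmulq_f32(vout, vscale);  // vout * alpha
  return vbslq_f32(pos, vout, neg);           // per-lane select
}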
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
 * width <= 4
...
...
@@ -1722,6 +2864,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out,
                                   const operators::ActivationParam act_param,
                                   ARMContext *ctx) {
  //! 3x3s1 convolution, implemented by direct algorithm
  //! pad is done implicit
...
...
@@ -1772,7 +2915,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
      if (hs == -1) {
        dr0 = zero;
      }
      switch (he - h_in) {
        case 2:
          dr2 = zero;
...
...
@@ -1782,127 +2924,19 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
        default:
          break;
      }
#ifdef __aarch64__
      if (flag_relu) {
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [zero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [out1] "r"(out_buf1),
                       [out2] "r"(out_buf2)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                       "v17");
      } else {
        asm volatile(COMPUTE_S_S1 RESULT_S_S1
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [zero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [out1] "r"(out_buf1),
                       [out2] "r"(out_buf2)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
                       "v17");
      }
#else
      if (flag_relu) {
        asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [out1] "r"(out_buf1),
                       [out2] "r"(out_buf2)
                     : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11",
                       "q12", "q13", "q14", "q15");
      } else {
        asm volatile(COMPUTE_S_S1 RESULT_S_S1
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [mask] "w"(vmask_rp),
                       [bias] "w"(wbias), [out1] "r"(out_buf1),
                       [out2] "r"(out_buf2)
                     : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11",
                       "q12", "q13", "q14", "q15");
      }
#endif
      act_switch_3x3s1p1_s(dr0, dr1, dr2, dr3, out_buf1, out_buf2, wr0, wr1,
                           wr2, vmask_rp, vzero, wbias, act_param);
      for (int w = 0; w < w_out; ++w) {
        *doutr0++ = out_buf1[w];
        *doutr1++ = out_buf2[w];
...
...
@@ -1916,6 +2950,490 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
    }
    // end of processing batchs
  }
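The small-width (*_bias_s) kernels above always compute four lanes per output row into the out_buf scratch arrays; only the first w_out lanes are real output. A minimal sketch of the copy-back that keeps the stores inside the row (the helper name is illustrative):

static void store_small_rows(float* doutr0,
                             float* doutr1,
                             const float out_buf1[4],
                             const float out_buf2[4],
                             int w_out) {
  // w_out <= 4 in this branch, so the vector stores never spill past the
  // end of the output row.
  for (int w = 0; w < w_out; ++w) {
    doutr0[w] = out_buf1[w];
    doutr1[w] = out_buf2[w];
  }
}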
#ifdef __aarch64__
void act_switch_3x3s1p0(const float* din_ptr0,
                        const float* din_ptr1,
                        const float* din_ptr2,
                        const float* din_ptr3,
                        const float* din_ptr4,
                        const float* din_ptr5,
                        float* doutr0,
                        float* doutr1,
                        float* doutr2,
                        float* doutr3,
                        float32x4_t wr0,
                        float32x4_t wr1,
                        float32x4_t wr2,
                        unsigned int* vmask,
                        unsigned int* rmask,
                        float32x4_t vzero,
                        float* vbias,
                        int cnt,
                        int remain,
                        const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
    float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
    float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
        asm volatile(
            INIT_S1
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" /*vld1q_f32(din_ptr0)*/
            "ext v16.16b, v0.16b, v1.16b, #4   \n" /* v16 = 1234 */
            "ext v17.16b, v0.16b, v1.16b, #8   \n" /* v17 = 2345 */
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" /*vld1q_f32(din_ptr0)*/
            MID_COMPUTE_S1 MID_RESULT_S1_RELU
            "cmp %w[remain], #1                \n"
            "blt 0f                            \n"
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
            "0:                                \n"
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
              [vzero] "w"(vzero), [remain] "r"(remain)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
        asm volatile(
            INIT_S1
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" /*vld1q_f32(din_ptr0)*/
            "ext v16.16b, v0.16b, v1.16b, #4   \n" /* v16 = 1234 */
            "ext v17.16b, v0.16b, v1.16b, #8   \n" /* v17 = 2345 */
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" /*vld1q_f32(din_ptr0)*/
            MID_COMPUTE_S1 MID_RESULT_S1_RELU6
            "cmp %w[remain], #1                \n"
            "blt 0f                            \n"
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
            "0:                                \n"
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2), [vsix] "w"(vsix),
              [bias_val] "r"(vbias), [vmask] "r"(vmask), [rmask] "r"(rmask),
              [remain] "r"(remain)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /* din = din >= 0 ? din : din * scale */
        asm volatile(
            INIT_S1
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" /*vld1q_f32(din_ptr0)*/
            "ext v16.16b, v0.16b, v1.16b, #4   \n" /* v16 = 1234 */
            "ext v17.16b, v0.16b, v1.16b, #8   \n" /* v17 = 2345 */
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" /*vld1q_f32(din_ptr0)*/
            MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
            "cmp %w[remain], #1                \n"
            "blt 0f                            \n"
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
            "0:                                \n"
            : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2),
              [vscale] "w"(vscale), [bias_val] "r"(vbias),
              [vmask] "r"(vmask), [rmask] "r"(rmask), [remain] "r"(remain)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
              "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
              "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
    asm volatile(
        INIT_S1
        "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" /*vld1q_f32(din_ptr0)*/
        "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" /*vld1q_f32(din_ptr0)*/
        "ext v16.16b, v0.16b, v1.16b, #4   \n" /* v16 = 1234 */
        "ext v17.16b, v0.16b, v1.16b, #8   \n" /* v17 = 2345 */
        "ld1 {v9.4s}, [%[din_ptr4]]        \n" /*vld1q_f32(din_ptr0)*/
        "ld1 {v11.4s}, [%[din_ptr5]]       \n" /*vld1q_f32(din_ptr0)*/
        MID_COMPUTE_S1 MID_RESULT_S1
        "cmp %w[remain], #1                \n"
        "blt 0f                            \n"
        RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
        "0:                                \n"
        : [cnt] "+r"(cnt), [din_ptr0] "+r"(din_ptr0),
          [din_ptr1] "+r"(din_ptr1), [din_ptr2] "+r"(din_ptr2),
          [din_ptr3] "+r"(din_ptr3), [din_ptr4] "+r"(din_ptr4),
          [din_ptr5] "+r"(din_ptr5), [doutr0] "+r"(doutr0),
          [doutr1] "+r"(doutr1), [doutr2] "+r"(doutr2), [doutr3] "+r"(doutr3)
        : [w0] "w"(wr0), [w1] "w"(wr1), [w2] "w"(wr2), [bias_val] "r"(vbias),
          [vmask] "r"(vmask), [rmask] "r"(rmask), [vzero] "w"(vzero),
          [remain] "r"(remain)
        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
          "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
          "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
  }
}
#else
void act_switch_3x3s1p0(const float* din_ptr0,
                        const float* din_ptr1,
                        const float* din_ptr2,
                        const float* din_ptr3,
                        float* doutr0,
                        float* doutr1,
                        float32x4_t wr0,
                        float32x4_t wr1,
                        float32x4_t wr2,
                        unsigned int* vmask_ptr,
                        unsigned int* rmask_ptr,
                        float32x4_t vzero,
                        float bias_val,
                        int cnt,
                        int remain,
                        const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
    float tmp = act_param.Relu_clipped_coef;
    float ss = act_param.Leaky_relu_alpha;
    float vsix[4] = {tmp, tmp, tmp, tmp};
    float vscale[4] = {ss, ss, ss, ss};
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
        asm volatile(
            INIT_S1
            "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap \n"
            "vext.32 q6, q8, q9, #1 @ 0012 \n"
            "vext.32 q7, q8, q9, #2 @ 1234 \n"
            MID_COMPUTE_S1 MID_RESULT_S1_RELU
            "cmp %[remain], #1 \n"
            "blt 0f \n"
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
            "0: \n"
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [bias_val] "r"(bias_val), [vzero] "w"(vzero),
              [remain] "r"(remain)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
        asm volatile(
            INIT_S1
            "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap \n"
            "vext.32 q6, q8, q9, #1 @ 0012 \n"
            "vext.32 q7, q8, q9, #2 @ 1234 \n"
            MID_COMPUTE_S1 MID_RESULT_S1_RELU6
            "cmp %[remain], #1 \n"
            "blt 0f \n"
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
            "0: \n"
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [six_ptr] "r"(vsix), [bias_val] "r"(bias_val),
              [vzero] "w"(vzero), [remain] "r"(remain)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /* din = din >= 0 ? din : din * scale */
        asm volatile(
            INIT_S1
            "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap \n"
            "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap \n"
            "vext.32 q6, q8, q9, #1 @ 0012 \n"
            "vext.32 q7, q8, q9, #2 @ 1234 \n"
            MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
            "cmp %[remain], #1 \n"
            "blt 0f \n"
            RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
            "0: \n"
            : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
              [scale_ptr] "r"(vscale), [bias_val] "r"(bias_val),
              [vzero] "w"(vzero), [remain] "r"(remain)
            : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15");
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
    asm volatile(
        INIT_S1
        "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap \n"
        "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap \n"
        "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap \n"
        "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap \n"
        "vext.32 q6, q8, q9, #1 @ 0012 \n"
        "vext.32 q7, q8, q9, #2 @ 1234 \n"
        MID_COMPUTE_S1 MID_RESULT_S1
        "cmp %[remain], #1 \n"
        "blt 0f \n"
        RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
        "0: \n"
        : [dout_ptr1] "+r"(doutr0), [dout_ptr2] "+r"(doutr1),
          [din0_ptr] "+r"(din_ptr0), [din1_ptr] "+r"(din_ptr1),
          [din2_ptr] "+r"(din_ptr2), [din3_ptr] "+r"(din_ptr3),
          [cnt] "+r"(cnt), [rmask] "+r"(rmask_ptr), [vmask] "+r"(vmask_ptr)
        : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
          [bias_val] "r"(bias_val), [vzero] "w"(vzero), [remain] "r"(remain)
        : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
          "q12", "q13", "q14", "q15");
  }
}
#endif
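Both variants consume a cnt/remain pair that partitions each output row into full 4-column tiles (the MID_* macros) and one masked tail tile (the RIGHT_* macros, entered only when remain >= 1 via the cmp/blt pair above). A sketch of that partition, assuming 4-wide tiles; the names mirror the callers, and the exact expressions used elsewhere in this file may differ:

struct ColTiles {
  int cnt;     // full 4-column tiles, handled by the MID_* macros
  int remain;  // 0..3 tail columns, handled by the masked RIGHT_* macros
};

inline ColTiles partition_cols(int w_out) {
  ColTiles t;
  t.cnt = w_out >> 2;                // number of complete vector tiles
  t.remain = w_out - (t.cnt << 2);   // leftover columns at the right edge
  return t;
}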
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width > 4
...
...
@@ -1932,6 +3450,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
                                 const int w_in,
                                 const int h_out,
                                 const int w_out,
                                 const operators::ActivationParam act_param,
                                 ARMContext *ctx) {
  //! pad is done implicit
  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
...
...
@@ -2060,15 +3579,16 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      }
      int cnt = tile_w;
      /*
      if (flag_relu) {
        asm volatile(
            INIT_S1
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" /*vld1q_f32(din_ptr0)*/
            "ext v16.16b, v0.16b, v1.16b, #4   \n" /* v16 = 1234 */
            "ext v17.16b, v0.16b, v1.16b, #8   \n" /* v17 = 2345 */
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" // vld1q_f32(din_ptr0)
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" // vld1q_f32(din_ptr0)
            "ext v16.16b, v0.16b, v1.16b, #4   \n" // v16 = 1234
            "ext v17.16b, v0.16b, v1.16b, #8   \n" // v17 = 2345
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" // vld1q_f32(din_ptr0)
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" // vld1q_f32(din_ptr0)
            MID_COMPUTE_S1 MID_RESULT_S1_RELU
            "cmp %w[remain], #1 \n"
            "blt 0f \n" RIGHT_COMPUTE_S1
...
...
@@ -2123,12 +3643,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      } else {
        asm volatile(
            INIT_S1
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" /*vld1q_f32(din_ptr0)*/
            "ext v16.16b, v0.16b, v1.16b, #4   \n" /* v16 = 1234 */
            "ext v17.16b, v0.16b, v1.16b, #8   \n" /* v17 = 2345 */
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" /*vld1q_f32(din_ptr0)*/
            "ld1 {v8.4s}, [%[din_ptr4]], #16   \n" // vld1q_f32(din_ptr0)
            "ld1 {v10.4s}, [%[din_ptr5]], #16  \n" // vld1q_f32(din_ptr0)
            "ext v16.16b, v0.16b, v1.16b, #4   \n" // v16 = 1234
            "ext v17.16b, v0.16b, v1.16b, #8   \n" // v17 = 2345
            "ld1 {v9.4s}, [%[din_ptr4]]        \n" // vld1q_f32(din_ptr0)
            "ld1 {v11.4s}, [%[din_ptr5]]       \n" // vld1q_f32(din_ptr0)
            MID_COMPUTE_S1 MID_RESULT_S1
            "cmp %w[remain], #1 \n"
            "blt 0f \n" RIGHT_COMPUTE_S1
...
...
@@ -2181,6 +3701,27 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
                     "v24",
                     "v25");
      }
      */
      act_switch_3x3s1p0(din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4,
                         din_ptr5, doutr0, doutr1, doutr2, doutr3, wr0, wr1,
                         wr2, vmask, rmask, vzero, vbias, cnt, remain,
                         act_param);
      dout_ptr = dout_ptr + 4 * w_out;
    }
#else
...
...
@@ -2219,6 +3760,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      int cnt = tile_w;
      unsigned int *rmask_ptr = rmask;
      unsigned int *vmask_ptr = vmask;
      /*
      if (flag_relu) {
        asm volatile(INIT_S1
                     "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
...
...
@@ -2301,13 +3843,328 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
                     "q13",
                     "q14",
                     "q15");
        }
      }*/
      act_switch_3x3s1p0(din_ptr0, din_ptr1, din_ptr2, din_ptr3, doutr0,
                         doutr1, wr0, wr1, wr2, vmask_ptr, rmask_ptr, vzero,
                         bias_val, cnt, remain, act_param);
      dout_ptr += 2 * w_out;
    }
    //! end of processing mid rows
#endif
    }
  }
}
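The vmask/rmask tables consumed above gate the input loads and output stores at the right edge of a row. A hedged sketch of how such a per-lane tail mask can be produced; the helper and the index ramp are illustrative, since the kernels precompute their masks elsewhere in this file:

#include <arm_neon.h>

static uint32x4_t make_tail_mask(int valid_cols) {  // valid_cols in [0, 4]
  const int32_t idx[4] = {0, 1, 2, 3};
  int32x4_t vidx = vld1q_s32(idx);
  // lane < valid_cols ? all-ones : zero; usable with vbslq_f32 to blend
  // padded lanes against real data.
  return vcltq_s32(vidx, vdupq_n_s32(valid_cols));
}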
void act_switch_3x3s1p0_s(const float* din_ptr0,
                          const float* din_ptr1,
                          const float* din_ptr2,
                          const float* din_ptr3,
                          float* doutr0,
                          float* doutr1,
                          float32x4_t wr0,
                          float32x4_t wr1,
                          float32x4_t wr2,
                          uint32x4_t vmask_rp1,
                          uint32x4_t vmask_rp2,
                          float32x4_t vzero,
                          float32x4_t wbias,
                          unsigned int* vmask_ptr,
                          float bias_val,
                          const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
#ifdef __aarch64__
    float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
    float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
    float tmp = act_param.Relu_clipped_coef;
    float ss = act_param.Leaky_relu_alpha;
    float vsix[4] = {tmp, tmp, tmp, tmp};
    float vscale[4] = {ss, ss, ss, ss};
#endif
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2), [vzero] "w"(vzero),
                       [out1] "r"(doutr0), [out2] "r"(doutr1)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
                       "v14", "v15");
        break;
#else
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [bias_val] "r"(bias_val),
                       [out1] "r"(doutr0), [out2] "r"(doutr1)
                     : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q13", "q14", "q15");
        break;
#endif
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2), [vzero] "w"(vzero),
                       [vsix] "w"(vsix), [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
                       "v14", "v15");
        break;
#else
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [six_ptr] "r"(vsix),
                       [bias_val] "r"(bias_val), [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
                     : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q13", "q14", "q15");
        break;
#endif
      case lite_api::ActivationType::kLeakyRelu:
        /* din = din >= 0 ? din : din * scale */
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2), [vzero] "w"(vzero),
                       [vscale] "w"(vscale), [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
                       "v14", "v15");
        break;
#else
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
                     : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                       [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [scale_ptr] "r"(vscale),
                       [bias_val] "r"(bias_val), [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
                     : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q13", "q14", "q15");
        break;
#endif
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
#ifdef __aarch64__
    asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
                 : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                   [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3)
                 : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                   [vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
                   [mask2] "w"(vmask_rp2), [vzero] "w"(vzero),
                   [out1] "r"(doutr0), [out2] "r"(doutr1)
                 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
                   "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
                   "v15");
#else
    asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
                 : [din0] "+r"(din_ptr0), [din1] "+r"(din_ptr1),
                   [din2] "+r"(din_ptr2), [din3] "+r"(din_ptr3),
                   [vmask] "+r"(vmask_ptr)
                 : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                   [vzero] "w"(vzero), [bias_val] "r"(bias_val),
                   [out1] "r"(doutr0), [out2] "r"(doutr1)
                 : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
                   "q11", "q12", "q13", "q14", "q15");
#endif
  }
}
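A note on the constraint choice above: on armv7 vmask_ptr is a read-write ("+r") operand because the asm advances the pointer while pulling in the mask vectors, so the compiler must not assume it still holds the original address afterwards. An intrinsic rendering of that access pattern (function name illustrative; the real loads happen inside COMPUTE_S_S1_P0):

#include <arm_neon.h>

static unsigned int* load_right_masks(unsigned int* vmask_ptr,
                                      uint32x4_t* m0,
                                      uint32x4_t* m1) {
  *m0 = vld1q_u32(vmask_ptr);  // mask words for columns 0..3
  vmask_ptr += 4;
  *m1 = vld1q_u32(vmask_ptr);  // mask words for columns 4..7
  vmask_ptr += 4;
  return vmask_ptr;            // advanced, matching the "+r" constraint
}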
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width <= 4
...
...
@@ -2324,6 +4181,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out,
                                   const operators::ActivationParam act_param,
                                   ARMContext *ctx) {
  //! 3x3s1 convolution, implemented by direct algorithm
  //! pad is done implicit
...
...
@@ -2355,15 +4213,22 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
      // #ifdef __aarch64__
      //   float32x4_t wbias;
      //   if (flag_bias) {
      //     wbias = vdupq_n_f32(bias[i]);
      //   } else {
      //     wbias = vdupq_n_f32(0.f);
      //   }
      // #endif  // __aarch64__
      float32x4_t wbias;
      float bias_val = 0.f;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
        bias_val = bias[i];
      } else {
        wbias = vdupq_n_f32(0.f);
      }
#endif  // __aarch64__
      float out_buf1[4];
      float out_buf2[4];
      float trash_buf[4];
...
...
@@ -2396,135 +4261,154 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
            break;
        }
      }
#ifdef __aarch64__
      if (flag_relu) {
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2), [zero] "w"(vzero),
                       [out1] "r"(out_buf1), [out2] "r"(out_buf2)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
                       "v14", "v15");
      } else {
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vbias] "w"(wbias), [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2), [zero] "w"(vzero),
                       [out1] "r"(out_buf1), [out2] "r"(out_buf2)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
                       "v14", "v15");
      }
#else
/*
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
*/
      unsigned int *vmask_ptr = vmask;
      float bias_val = flag_bias ? bias[i] : 0.f;
      if (flag_relu) {
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [bias_val] "r"(bias_val),
                       [out1] "r"(out_buf1), [out2] "r"(out_buf2)
                     : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q13", "q14", "q15");
      } else {
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
                     : [din0] "+r"(dr0), [din1] "+r"(dr1),
                       [din2] "+r"(dr2), [din3] "+r"(dr3),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0), [wr1] "w"(wr1), [wr2] "w"(wr2),
                       [vzero] "w"(vzero), [bias_val] "r"(bias_val),
                       [out1] "r"(out_buf1), [out2] "r"(out_buf2)
                     : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q13", "q14", "q15");
      }
#endif
      act_switch_3x3s1p0_s(dr0, dr1, dr2, dr3, out_buf1, out_buf2, wr0, wr1,
                           wr2, vmask_rp1, vmask_rp2, vzero, wbias,
                           vmask_ptr, bias_val, act_param);
      for (int w = 0; w < w_out; ++w) {
        *doutr0++ = out_buf1[w];
        *doutr1++ = out_buf2[w];
...
...
lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc
View file @ 3455ab0a
...
...
@@ -25,6 +25,785 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
// clang-format off
#ifdef __aarch64__
#define COMPUTE                                                   \
  "ldp q0, q1,   [%[inr0]], #32\n"  /* load input r0*/            \
  "ldp q6, q7,   [%[inr1]], #32\n"  /* load input r1*/            \
  "ldp q2, q3,   [%[inr0]], #32\n"  /* load input r0*/            \
  "ldp q8, q9,   [%[inr1]], #32\n"  /* load input r1*/            \
  "ldp q4, q5,   [%[inr0]]\n"       /* load input r0*/            \
  "ldp q10, q11, [%[inr1]]\n"       /* load input r1*/            \
  /* r0, r1, mul w0, get out r0, r1 */                            \
  "fmul v15.4s , %[w0].4s, v0.4s\n"  /* outr00 = w0 * r0, 0*/     \
  "fmul v16.4s , %[w0].4s, v1.4s\n"  /* outr01 = w0 * r0, 1*/     \
  "fmul v17.4s , %[w0].4s, v2.4s\n"  /* outr02 = w0 * r0, 2*/     \
  "fmul v18.4s , %[w0].4s, v3.4s\n"  /* outr03 = w0 * r0, 3*/     \
  "fmul v19.4s , %[w0].4s, v6.4s\n"  /* outr10 = w0 * r1, 0*/     \
  "fmul v20.4s , %[w0].4s, v7.4s\n"  /* outr11 = w0 * r1, 1*/     \
  "fmul v21.4s , %[w0].4s, v8.4s\n"  /* outr12 = w0 * r1, 2*/     \
  "fmul v22.4s , %[w0].4s, v9.4s\n"  /* outr13 = w0 * r1, 3*/     \
  /* r0, r1, mul w1, get out r0, r1 */                            \
  "fmla v15.4s , %[w1].4s, v1.4s\n"  /* outr00 = w1 * r0[1]*/     \
  "ldp q0, q1,   [%[inr2]], #32\n"   /* load input r2*/           \
  "fmla v16.4s , %[w1].4s, v2.4s\n"  /* outr01 = w1 * r0[2]*/     \
  "fmla v17.4s , %[w1].4s, v3.4s\n"  /* outr02 = w1 * r0[3]*/     \
  "fmla v18.4s , %[w1].4s, v4.4s\n"  /* outr03 = w1 * r0[4]*/     \
  "fmla v19.4s , %[w1].4s, v7.4s\n"  /* outr10 = w1 * r1[1]*/     \
  "fmla v20.4s , %[w1].4s, v8.4s\n"  /* outr11 = w1 * r1[2]*/     \
  "fmla v21.4s , %[w1].4s, v9.4s\n"  /* outr12 = w1 * r1[3]*/     \
  "fmla v22.4s , %[w1].4s, v10.4s\n" /* outr13 = w1 * r1[4]*/     \
  /* r0, r1, mul w2, get out r0, r1 */                            \
  "fmla v15.4s , %[w2].4s, v2.4s\n"  /* outr00 = w2 * r0[2]*/     \
  "fmla v16.4s , %[w2].4s, v3.4s\n"  /* outr01 = w2 * r0[3]*/     \
  "ldp q2, q3,   [%[inr2]], #32\n"   /* load input r2*/           \
  "fmla v17.4s , %[w2].4s, v4.4s\n"  /* outr02 = w2 * r0[4]*/     \
  "fmla v18.4s , %[w2].4s, v5.4s\n"  /* outr03 = w2 * r0[5]*/     \
  "ldp q4, q5,   [%[inr2]]\n"        /* load input r2*/           \
  "fmla v19.4s , %[w2].4s, v8.4s\n"  /* outr10 = w2 * r1[2]*/     \
  "fmla v20.4s , %[w2].4s, v9.4s\n"  /* outr11 = w2 * r1[3]*/     \
  "fmla v21.4s , %[w2].4s, v10.4s\n" /* outr12 = w2 * r1[4]*/     \
  "fmla v22.4s , %[w2].4s, v11.4s\n" /* outr13 = w2 * r1[5]*/     \
  /* r1, r2, mul w3, get out r0, r1 */                            \
  "fmla v15.4s , %[w3].4s, v6.4s\n"  /* outr00 = w3 * r1[0]*/     \
  "fmla v16.4s , %[w3].4s, v7.4s\n"  /* outr01 = w3 * r1[1]*/     \
  "fmla v17.4s , %[w3].4s, v8.4s\n"  /* outr02 = w3 * r1[2]*/     \
  "fmla v18.4s , %[w3].4s, v9.4s\n"  /* outr03 = w3 * r1[3]*/     \
  "fmla v19.4s , %[w3].4s, v0.4s\n"  /* outr10 = w3 * r2[0]*/     \
  "fmla v20.4s , %[w3].4s, v1.4s\n"  /* outr11 = w3 * r2[1]*/     \
  "fmla v21.4s , %[w3].4s, v2.4s\n"  /* outr12 = w3 * r2[2]*/     \
  "fmla v22.4s , %[w3].4s, v3.4s\n"  /* outr13 = w3 * r2[3]*/     \
  /* r1, r2, mul w4, get out r0, r1 */                            \
  "fmla v15.4s , %[w4].4s, v7.4s\n"  /* outr00 = w4 * r1[1]*/     \
  "ldp q6, q7,   [%[inr3]], #32\n"   /* load input r3*/           \
  "fmla v16.4s , %[w4].4s, v8.4s\n"  /* outr01 = w4 * r1[2]*/     \
  "fmla v17.4s , %[w4].4s, v9.4s\n"  /* outr02 = w4 * r1[3]*/     \
  "fmla v18.4s , %[w4].4s, v10.4s\n" /* outr03 = w4 * r1[4]*/     \
  "ldp x0, x1,   [%[outl]] \n"                                    \
  "fmla v19.4s , %[w4].4s, v1.4s\n"  /* outr10 = w4 * r2[1]*/     \
  "fmla v20.4s , %[w4].4s, v2.4s\n"  /* outr11 = w4 * r2[2]*/     \
  "fmla v21.4s , %[w4].4s, v3.4s\n"  /* outr12 = w4 * r2[3]*/     \
  "fmla v22.4s , %[w4].4s, v4.4s\n"  /* outr13 = w4 * r2[4]*/     \
  /* r1, r2, mul w5, get out r0, r1 */                            \
  "fmla v15.4s , %[w5].4s, v8.4s\n"  /* outr00 = w5 * r1[2]*/     \
  "fmla v16.4s , %[w5].4s, v9.4s\n"  /* outr01 = w5 * r1[3]*/     \
  "ldp q8, q9,   [%[inr3]], #32\n"   /* load input r3*/           \
  "fmla v17.4s , %[w5].4s, v10.4s\n" /* outr02 = w5 * r1[4]*/     \
  "fmla v18.4s , %[w5].4s, v11.4s\n" /* outr03 = w5 * r1[5]*/     \
  "ldp q10, q11, [%[inr3]]\n"        /* load input r3*/           \
  "fmla v19.4s , %[w5].4s, v2.4s\n"  /* outr10 = w5 * r2[2]*/     \
  "fmla v20.4s , %[w5].4s, v3.4s\n"  /* outr11 = w5 * r2[3]*/     \
  "fmla v21.4s , %[w5].4s, v4.4s\n"  /* outr12 = w5 * r2[4]*/     \
  "fmla v22.4s , %[w5].4s, v5.4s\n"  /* outr13 = w5 * r2[5]*/     \
  /* r2, r3, mul w6, get out r0, r1 */                            \
  "fmla v15.4s , %[w6].4s, v0.4s\n"  /* outr00 = w6 * r2[0]*/     \
  "fmla v16.4s , %[w6].4s, v1.4s\n"  /* outr01 = w6 * r2[1]*/     \
  "fmla v17.4s , %[w6].4s, v2.4s\n"  /* outr02 = w6 * r2[2]*/     \
  "fmla v18.4s , %[w6].4s, v3.4s\n"  /* outr03 = w6 * r2[3]*/     \
  "ldp x2, x3,   [%[outl], #16] \n"                               \
  "fmla v19.4s , %[w6].4s, v6.4s\n"  /* outr10 = w6 * r3[0]*/     \
  "fmla v20.4s , %[w6].4s, v7.4s\n"  /* outr11 = w6 * r3[1]*/     \
  "fmla v21.4s , %[w6].4s, v8.4s\n"  /* outr12 = w6 * r3[2]*/     \
  "fmla v22.4s , %[w6].4s, v9.4s\n"  /* outr13 = w6 * r3[3]*/     \
  /* r2, r3, mul w7, get out r0, r1 */                            \
  "fmla v15.4s , %[w7].4s, v1.4s\n"  /* outr00 = w7 * r2[1]*/     \
  "fmla v16.4s , %[w7].4s, v2.4s\n"  /* outr01 = w7 * r2[2]*/     \
  "fmla v17.4s , %[w7].4s, v3.4s\n"  /* outr02 = w7 * r2[3]*/     \
  "fmla v18.4s , %[w7].4s, v4.4s\n"  /* outr03 = w7 * r2[4]*/     \
  "ldp x4, x5,   [%[outl], #32] \n"                               \
  "fmla v19.4s , %[w7].4s, v7.4s\n"  /* outr10 = w7 * r3[1]*/     \
  "fmla v20.4s , %[w7].4s, v8.4s\n"  /* outr11 = w7 * r3[2]*/     \
  "fmla v21.4s , %[w7].4s, v9.4s\n"  /* outr12 = w7 * r3[3]*/     \
  "fmla v22.4s , %[w7].4s, v10.4s\n" /* outr13 = w7 * r3[4]*/     \
  /* r2, r3, mul w8, get out r0, r1 */                            \
  "fmla v15.4s , %[w8].4s, v2.4s\n"  /* outr00 = w8 * r2[2]*/     \
  "fmla v16.4s , %[w8].4s, v3.4s\n"  /* outr01 = w8 * r2[3]*/     \
  "fmla v17.4s , %[w8].4s, v4.4s\n"  /* outr02 = w8 * r2[0]*/     \
  "fmla v18.4s , %[w8].4s, v5.4s\n"  /* outr03 = w8 * r2[1]*/     \
  "ldp x6, x7,   [%[outl], #48] \n"                               \
  "fmla v19.4s , %[w8].4s, v8.4s\n"  /* outr10 = w8 * r3[2]*/     \
  "fmla v20.4s , %[w8].4s, v9.4s\n"  /* outr11 = w8 * r3[3]*/     \
  "fmla v21.4s , %[w8].4s, v10.4s\n" /* outr12 = w8 * r3[0]*/     \
  "fmla v22.4s , %[w8].4s, v11.4s\n" /* outr13 = w8 * r3[1]*/     \
  "fadd v15.4s, v15.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v16.4s, v16.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v17.4s, v17.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v18.4s, v18.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v19.4s, v19.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v20.4s, v20.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v21.4s, v21.4s, %[vbias].4s\n" /* add bias */             \
  "fadd v22.4s, v22.4s, %[vbias].4s\n" /* add bias */             \
  /* transpose */                                                 \
  "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/                \
  "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/                \
  "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/                \
  "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/                \
  "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/                \
  "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/                \
  "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/                \
  "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/                \
  "trn1 v15.2d, v0.2d, v2.2d\n"  /* r0: a0a1a2a3*/                \
  "trn2 v19.2d, v0.2d, v2.2d\n"  /* r0: c0c1c2c3*/                \
  "trn1 v17.2d, v1.2d, v3.2d\n"  /* r0: b0b1b2b3*/                \
  "trn2 v21.2d, v1.2d, v3.2d\n"  /* r0: d0d1d2d3*/                \
  "trn1 v16.2d, v4.2d, v6.2d\n"  /* r1: a0a1a2a3*/                \
  "trn2 v20.2d, v4.2d, v6.2d\n"  /* r1: c0c1c2c3*/                \
  "trn1 v18.2d, v5.2d, v7.2d\n"  /* r1: b0b1b2b3*/                \
  "trn2 v22.2d, v5.2d, v7.2d\n"  /* r1: d0d1d2d3*/
#define RELU                                                      \
  "movi v0.4s, #0\n" /* for relu */                               \
  "ldr x0, [%[outl], #80]\n"                                      \
  "fmax v15.4s, v15.4s, v0.4s\n"                                  \
  "fmax v16.4s, v16.4s, v0.4s\n"                                  \
  "fmax v17.4s, v17.4s, v0.4s\n"                                  \
  "fmax v18.4s, v18.4s, v0.4s\n"                                  \
  "ld1 {v1.4s}, [x0]\n"                                           \
  "fmax v19.4s, v19.4s, v0.4s\n"                                  \
  "fmax v20.4s, v20.4s, v0.4s\n"                                  \
  "fmax v21.4s, v21.4s, v0.4s\n"                                  \
  "fmax v22.4s, v22.4s, v0.4s\n"                                  \
  "ldr x0, [%[outl]]\n"
#define RELU6                                                     \
  "fmin v15.4s, v15.4s, v1.4s\n"                                  \
  "fmin v16.4s, v16.4s, v1.4s\n"                                  \
  "fmin v17.4s, v17.4s, v1.4s\n"                                  \
  "fmin v18.4s, v18.4s, v1.4s\n"                                  \
  "fmin v19.4s, v19.4s, v1.4s\n"                                  \
  "fmin v20.4s, v20.4s, v1.4s\n"                                  \
  "fmin v21.4s, v21.4s, v1.4s\n"                                  \
  "fmin v22.4s, v22.4s, v1.4s\n"
#define LEAKY_RELU                                                \
  "movi v0.4s, #0\n" /* for relu */                               \
  "ldr x0, [%[outl], #88]\n"                                      \
  "cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "ld1 {v9.4s}, [x0] \n"                                          \
  "cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "cmhs v4.4s, v18.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "ldr x0, [%[outl]] \n"                                          \
  "fmul v5.4s, v15.4s, v9.4s \n" /* mul */                        \
  "fmul v6.4s, v16.4s, v9.4s \n" /* mul */                        \
  "fmul v7.4s, v17.4s, v9.4s \n" /* mul */                        \
  "fmul v8.4s, v18.4s, v9.4s \n" /* mul */                        \
  "bif v15.16b, v5.16b, v1.16b \n" /* choose*/                    \
  "bif v16.16b, v6.16b, v2.16b \n" /* choose*/                    \
  "bif v17.16b, v7.16b, v3.16b \n" /* choose*/                    \
  "bif v18.16b, v8.16b, v4.16b \n" /* choose*/                    \
  "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */                  \
  "fmul v5.4s, v19.4s, v9.4s \n" /* mul */                        \
  "fmul v6.4s, v20.4s, v9.4s \n" /* mul */                        \
  "fmul v7.4s, v21.4s, v9.4s \n" /* mul */                        \
  "fmul v8.4s, v22.4s, v9.4s \n" /* mul */                        \
  "bif v19.16b, v5.16b, v1.16b \n" /* choose*/                    \
  "bif v20.16b, v6.16b, v2.16b \n" /* choose*/                    \
  "bif v21.16b, v7.16b, v3.16b \n" /* choose*/                    \
  "bif v22.16b, v8.16b, v4.16b \n" /* choose*/
#define STORE                                                     \
  "cbnz %w[flag_mask], 1f\n"                                      \
  "str q15, [x0]\n" /* save outc00 */                             \
  "str q16, [x4]\n" /* save outc01 */                             \
  "str q17, [x1]\n" /* save outc10 */                             \
  "str q18, [x5]\n" /* save outc11 */                             \
  "str q19, [x2]\n" /* save outc20 */                             \
  "str q20, [x6]\n" /* save outc21 */                             \
  "str q21, [x3]\n" /* save outc30 */                             \
  "str q22, [x7]\n" /* save outc31 */                             \
  "b 2f\n"                                                        \
  "1:\n"                                                          \
  "str q15, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q17, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q19, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q21, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q16, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q18, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q20, [%[out]], #16 \n" /* save remain to pre_out */        \
  "str q22, [%[out]], #16 \n" /* save remain to pre_out */        \
  "2:\n"
#else
#define COMPUTE \
  /* load weights */ \
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" \
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" \
  /* load r0, r1 */ \
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" \
"vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n" \
  /* main loop */ \
"0: @ main loop\n" \
  /* mul r0 with w0, w1, w2, get out r0 */ \
"vmul.f32 q8, q5, q0 @ w0 * inr00\n" \
"vmul.f32 q9, q5, q1 @ w0 * inr01\n" \
"vmul.f32 q10, q5, q2 @ w0 * inr02\n" \
"vmul.f32 q11, q5, q3 @ w0 * inr03\n" \
"vmla.f32 q8, q6, q1 @ w1 * inr01\n" \
"vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" \
"vmla.f32 q9, q6, q2 @ w1 * inr02\n" \
"vmla.f32 q10, q6, q3 @ w1 * inr03\n" \
"vmla.f32 q11, q6, q0 @ w1 * inr04\n" \
"vmla.f32 q8, q7, q2 @ w2 * inr02\n" \
"vmla.f32 q9, q7, q3 @ w2 * inr03\n" \
"vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" \
"vmla.f32 q10, q7, q0 @ w2 * inr04\n" \
"vmla.f32 q11, q7, q1 @ w2 * inr05\n" \
"vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" \
"vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" \
  /* mul r1 with w0-w5, get out r0, r1 */ \
"vmul.f32 q12, q5, q2 @ w0 * inr10\n" \
"vmul.f32 q13, q5, q3 @ w0 * inr11\n" \
"vmul.f32 q14, q5, q0 @ w0 * inr12\n" \
"vmul.f32 q15, q5, q1 @ w0 * inr13\n" \
"vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n" \
"vmla.f32 q8, q4, q2 @ w3 * inr10\n" \
"vmla.f32 q9, q4, q3 @ w3 * inr11\n" \
"vmla.f32 q10, q4, q0 @ w3 * inr12\n" \
"vmla.f32 q11, q4, q1 @ w3 * inr13\n" \
  /* mul r1 with w1, w4, get out r1, r0 */ \
"vmla.f32 q8, q5, q3 @ w4 * inr11\n" \
"vmla.f32 q12, q6, q3 @ w1 * inr11\n" \
"vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" \
"vmla.f32 q9, q5, q0 @ w4 * inr12\n" \
"vmla.f32 q13, q6, q0 @ w1 * inr12\n" \
"vmla.f32 q10, q5, q1 @ w4 * inr13\n" \
"vmla.f32 q14, q6, q1 @ w1 * inr13\n" \
"vmla.f32 q11, q5, q2 @ w4 * inr14\n" \
"vmla.f32 q15, q6, q2 @ w1 * inr14\n" \
"vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" \
  /* mul r1 with w2, w5, get out r1, r0 */ \
"vmla.f32 q12, q7, q0 @ w2 * inr12\n" \
"vmla.f32 q13, q7, q1 @ w2 * inr13\n" \
"vmla.f32 q8, q6, q0 @ w5 * inr12\n" \
"vmla.f32 q9, q6, q1 @ w5 * inr13\n" \
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" \
"vmla.f32 q14, q7, q2 @ w2 * inr14\n" \
"vmla.f32 q15, q7, q3 @ w2 * inr15\n" \
"vmla.f32 q10, q6, q2 @ w5 * inr14\n" \
"vmla.f32 q11, q6, q3 @ w5 * inr15\n" \
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" \
"vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" \
  /* mul r2 with w3-w8, get out r0, r1 */ \
"vmla.f32 q12, q4, q0 @ w3 * inr20\n" \
"vmla.f32 q13, q4, q1 @ w3 * inr21\n" \
"vmla.f32 q14, q4, q2 @ w3 * inr22\n" \
"vmla.f32 q15, q4, q3 @ w3 * inr23\n" \
"vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" \
"vmla.f32 q8, q7, q0 @ w6 * inr20\n" \
"vmla.f32 q9, q7, q1 @ w6 * inr21\n" \
"vmla.f32 q10, q7, q2 @ w6 * inr22\n" \
"vmla.f32 q11, q7, q3 @ w6 * inr23\n" \
  /* mul r2 with w4, w7, get out r1, r0 */ \
"vmla.f32 q8, q4, q1 @ w7 * inr21\n" \
"vmla.f32 q12, q5, q1 @ w4 * inr21\n" \
"vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" \
"vmla.f32 q9, q4, q2 @ w7 * inr22\n" \
"vmla.f32 q13, q5, q2 @ w4 * inr22\n" \
"vmla.f32 q10, q4, q3 @ w7 * inr23\n" \
"vmla.f32 q14, q5, q3 @ w4 * inr23\n" \
"vmla.f32 q11, q4, q0 @ w7 * inr24\n" \
"vmla.f32 q15, q5, q0 @ w4 * inr24\n" \
"vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" \
  /* mul r1 with w5, w8, get out r1, r0 */ \
"vmla.f32 q12, q6, q2 @ w5 * inr22\n" \
"vmla.f32 q13, q6, q3 @ w5 * inr23\n" \
"vmla.f32 q8, q5, q2 @ w8 * inr22\n" \
"vmla.f32 q9, q5, q3 @ w8 * inr23\n" \
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" \
"ldr r4, [%[outl], #32] @ load bias addr to r4\n" \
"vmla.f32 q14, q6, q0 @ w5 * inr24\n" \
"vmla.f32 q15, q6, q1 @ w5 * inr25\n" \
"vmla.f32 q10, q5, q0 @ w8 * inr24\n" \
"vmla.f32 q11, q5, q1 @ w8 * inr25\n" \
"vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n" \
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \
  /* mul r3 with w6, w7, w8, get out r1 */ \
"vmla.f32 q12, q7, q2 @ w6 * inr30\n" \
"vmla.f32 q13, q7, q3 @ w6 * inr31\n" \
"vmla.f32 q14, q7, q0 @ w6 * inr32\n" \
"vmla.f32 q15, q7, q1 @ w6 * inr33\n" \
"vmla.f32 q12, q4, q3 @ w7 * inr31\n" \
"vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" \
"vld1.32 {d12-d13}, [r4] @ load bias\n" \
"vmla.f32 q13, q4, q0 @ w7 * inr32\n" \
"vmla.f32 q14, q4, q1 @ w7 * inr33\n" \
"vmla.f32 q15, q4, q2 @ w7 * inr34\n" \
"ldr r0, [%[outl]] @ load outc00 to r0\n" \
"vmla.f32 q12, q5, q0 @ w8 * inr32\n" \
"vmla.f32 q13, q5, q1 @ w8 * inr33\n" \
"ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \
"vmla.f32 q14, q5, q2 @ w8 * inr34\n" \
"vmla.f32 q15, q5, q3 @ w8 * inr35\n" \
"ldr r1, [%[outl], #4] @ load outc10 to r1\n" \
"vadd.f32 q8, q8, q6 @ r00 add bias\n" \
"vadd.f32 q9, q9, q6 @ r01 add bias\n" \
"vadd.f32 q10, q10, q6 @ r02 add bias\n" \
"vadd.f32 q11, q11, q6 @ r03 add bias\n" \
"ldr r2, [%[outl], #8] @ load outc20 to r2\n" \
"vadd.f32 q12, q12, q6 @ r10 add bias\n" \
"vadd.f32 q13, q13, q6 @ r11 add bias\n" \
"vadd.f32 q14, q14, q6 @ r12 add bias\n" \
"vadd.f32 q15, q15, q6 @ r13 add bias\n" \
"ldr r3, [%[outl], #12] @ load outc30 to r3\n" \
"vmov.u32 q7, #0 @ mov zero to q7\n"
#define RELU \
"vmax.f32 q8, q8, q7 @ r00 relu\n" \
"vmax.f32 q9, q9, q7 @ r01 relu\n" \
"vmax.f32 q10, q10, q7 @ r02 relu\n" \
"vmax.f32 q11, q11, q7 @ r03 relu\n" \
"vmax.f32 q12, q12, q7 @ r10 relu\n" \
"vmax.f32 q13, q13, q7 @ r11 relu\n" \
"vmax.f32 q14, q14, q7 @ r12 relu\n" \
"vmax.f32 q15, q15, q7 @ r13 relu\n"
#define RELU6 \
"ldr r4, [%[outl], #40] @ load six to r4\n" \
"vld1.32 {d12-d13}, [r4] @load data \n" \
"vmin.f32 q8, q8, q6 @ r00 relu\n" \
"vmin.f32 q9, q9, q6 @ r01 relu\n" \
"vmin.f32 q10, q10, q6 @ r02 relu\n" \
"vmin.f32 q11, q11, q6 @ r03 relu\n" \
"vmin.f32 q12, q12, q6 @ r10 relu\n" \
"vmin.f32 q13, q13, q6 @ r11 relu\n" \
"vmin.f32 q14, q14, q6 @ r12 relu\n" \
"vmin.f32 q15, q15, q6 @ r13 relu\n"
#define LEAKY_RELU \
"ldr r4, [%[outl], #44] @ load scale to r4\n" \
"vld1.32 {d12-d13}, [r4] @load data \n" \
"vcge.f32 q0, q8, q7 @ q0 > 0 \n" \
"vcge.f32 q1, q9, q7 @ q0 > 0 \n" \
"vmul.f32 q4, q8, q6 \n" \
"vmul.f32 q5, q9, q6 \n" \
"vcge.f32 q2, q10, q7 @ q0 > 0 \n" \
"vcge.f32 q3, q11, q7 @ q0 > 0 \n" \
"vbif q8, q4, q0 @ choose \n" \
"vbif q9, q5, q1 @ choose \n" \
"vmul.f32 q4, q10, q6 \n" \
"vmul.f32 q5, q11, q6 \n" \
"vbif q10, q4, q2 @ choose \n" \
"vbif q11, q5, q3 @ choose \n" \
"vcge.f32 q0, q12, q7 @ q0 > 0 \n" \
"vcge.f32 q1, q13, q7 @ q0 > 0 \n" \
"vmul.f32 q4, q12, q6 \n" \
"vmul.f32 q5, q13, q6 \n" \
"vcge.f32 q2, q14, q7 @ q0 > 0 \n" \
"vcge.f32 q3, q15, q7 @ q0 > 0 \n" \
"vbif q12, q4, q0 @ choose \n" \
"vbif q13, q5, q1 @ choose \n" \
"vmul.f32 q4, q14, q6 \n" \
"vmul.f32 q5, q15, q6 \n" \
"vbif q14, q4, q2 @ choose \n" \
"vbif q15, q5, q3 @ choose \n"
#define STORE \
"ldr r4, [%[outl], #16] @ load outc01 to r4\n" \
"vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" \
"vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \
"vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \
"ldr r5, [%[outl], #20] @ load outc11 to r5\n" \
"vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \
"vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" \
"vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \
"vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" \
"cmp %[flag_mask], #0 @ cmp flag mask\n" \
"bne 2f\n" \
"vst1.32 {d16-d17}, [r0] @ save outc00\n" \
"vst1.32 {d18-d19}, [r1] @ save outc10\n" \
"vst1.32 {d20-d21}, [r2] @ save outc20\n" \
"vst1.32 {d22-d23}, [r3] @ save outc30\n" \
"vst1.32 {d24-d25}, [r4] @ save outc01\n" \
"vst1.32 {d26-d27}, [r5] @ save outc11\n" \
"ldr r0, [%[outl], #24] @ load outc21 to r0\n" \
"ldr r1, [%[outl], #28] @ load outc31 to r1\n" \
"vst1.32 {d28-d29}, [r0] @ save outc21\n" \
"vst1.32 {d30-d31}, [r1] @ save outc31\n" \
"b 3f @ branch end\n" \
"2: \n" \
"vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" \
"3: \n"
#endif
// clang-format on
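These macros index a pointer table passed as [outl]; as the initializer further down in conv_3x3s1_depthwise_fp32 shows, it holds eight output-row pointers followed by the bias, relu, six and scale pointers. The asm offsets are pointer-sized, hence "[%[outl], #80]" / "[%[outl], #88]" in RELU/LEAKY_RELU on aarch64 versus "[%[outl], #40]" / "[%[outl], #44]" on armv7. Annotated for reference:

// Slot layout of the outl table read by COMPUTE / RELU / RELU6 /
// LEAKY_RELU / STORE (names match the function below):
float* outl[] = {outc00, outc10, outc20, outc30,        // slots 0-3
                 outc01, outc11, outc21, outc31,        // slots 4-7
                 reinterpret_cast<float*>(bias_local),  // slot 8
                 reinterpret_cast<float*>(relu_ptr),    // slot 9
                 reinterpret_cast<float*>(six_ptr),     // slot 10 (#80 / #40)
                 reinterpret_cast<float*>(scale_ptr)};  // slot 11 (#88 / #44)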
void act_switch_3x3s1(const float* inr0,
                      const float* inr1,
                      const float* inr2,
                      const float* inr3,
                      float* out0,
                      const float* weight_c,
                      float flag_mask,
                      void* outl_ptr,
                      float32x4_t w0,
                      float32x4_t w1,
                      float32x4_t w2,
                      float32x4_t w3,
                      float32x4_t w4,
                      float32x4_t w5,
                      float32x4_t w6,
                      float32x4_t w7,
                      float32x4_t w8,
                      float32x4_t vbias,
                      const operators::ActivationParam act_param) {
  bool has_active = act_param.has_active;
  if (has_active) {
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
        asm volatile(COMPUTE RELU STORE
                     : [inr0] "+r"(inr0), [inr1] "+r"(inr1),
                       [inr2] "+r"(inr2), [inr3] "+r"(inr3),
                       [out] "+r"(out0)
                     : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
                       [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
                       [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
                       [vbias] "w"(vbias), [outl] "r"(outl_ptr),
                       [flag_mask] "r"(flag_mask)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v15", "v16",
                       "v17", "v18", "v19", "v20", "v21", "v22", "x0", "x1",
                       "x2", "x3", "x4", "x5", "x6", "x7");
#else
        asm volatile(COMPUTE RELU STORE
                     : [r0] "+r"(inr0), [r1] "+r"(inr1), [r2] "+r"(inr2),
                       [r3] "+r"(inr3), [out0] "+r"(out0),
                       [wc0] "+r"(weight_c)
                     : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
                     : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5",
                       "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
                       "q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5");
#endif
        break;
      case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
        asm volatile(COMPUTE RELU RELU6 STORE
                     : [inr0] "+r"(inr0), [inr1] "+r"(inr1),
                       [inr2] "+r"(inr2), [inr3] "+r"(inr3),
                       [out] "+r"(out0)
                     : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
                       [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
                       [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
                       [vbias] "w"(vbias), [outl] "r"(outl_ptr),
                       [flag_mask] "r"(flag_mask)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v15", "v16",
                       "v17", "v18", "v19", "v20", "v21", "v22", "x0", "x1",
                       "x2", "x3", "x4", "x5", "x6", "x7");
#else
        asm volatile(COMPUTE RELU RELU6 STORE
                     : [r0] "+r"(inr0), [r1] "+r"(inr1), [r2] "+r"(inr2),
                       [r3] "+r"(inr3), [out0] "+r"(out0),
                       [wc0] "+r"(weight_c)
                     : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
                     : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5",
                       "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
                       "q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5");
#endif
        break;
      case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
        asm volatile(COMPUTE LEAKY_RELU STORE
                     : [inr0] "+r"(inr0), [inr1] "+r"(inr1),
                       [inr2] "+r"(inr2), [inr3] "+r"(inr3),
                       [out] "+r"(out0)
                     : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
                       [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
                       [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
                       [vbias] "w"(vbias), [outl] "r"(outl_ptr),
                       [flag_mask] "r"(flag_mask)
                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5",
                       "v6", "v7", "v8", "v9", "v10", "v11", "v15", "v16",
                       "v17", "v18", "v19", "v20", "v21", "v22", "x0", "x1",
                       "x2", "x3", "x4", "x5", "x6", "x7");
#else
        asm volatile(COMPUTE LEAKY_RELU STORE
                     : [r0] "+r"(inr0), [r1] "+r"(inr1), [r2] "+r"(inr2),
                       [r3] "+r"(inr3), [out0] "+r"(out0),
                       [wc0] "+r"(weight_c)
                     : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
                     : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5",
                       "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
                       "q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5");
#endif
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  } else {
#ifdef __aarch64__
    asm volatile(COMPUTE STORE
                 : [inr0] "+r"(inr0), [inr1] "+r"(inr1), [inr2] "+r"(inr2),
                   [inr3] "+r"(inr3), [out] "+r"(out0)
                 : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3),
                   [w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7),
                   [w8] "w"(w8), [vbias] "w"(vbias), [outl] "r"(outl_ptr),
                   [flag_mask] "r"(flag_mask)
                 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
                   "v7", "v8", "v9", "v10", "v11", "v15", "v16", "v17",
                   "v18", "v19", "v20", "v21", "v22", "x0", "x1", "x2",
                   "x3", "x4", "x5", "x6", "x7");
#else
    asm volatile(COMPUTE STORE
                 : [r0] "+r"(inr0), [r1] "+r"(inr1), [r2] "+r"(inr2),
                   [r3] "+r"(inr3), [out0] "+r"(out0), [wc0] "+r"(weight_c)
                 : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
                 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                   "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
                   "q15", "r0", "r1", "r2", "r3", "r4", "r5");
#endif
  }
}
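One detail worth calling out: on aarch64 the nine weight vectors arrive as "w" register operands, while the armv7 COMPUTE macro streams them from weight_c inside the loop and ends with "sub %[wc0], %[wc0], #144" to rewind the pointer for the next tile. The constant is easy to sanity-check:

#include <cstddef>

constexpr int kWeightVecs = 9;                    // w0 .. w8
constexpr int kBytesPerVec = 4 * sizeof(float);   // one float32x4_t
static_assert(kWeightVecs * kBytesPerVec == 144,
              "weight rewind must match the bytes loaded per tile");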
void conv_3x3s1_depthwise_fp32(const float* i_data,
                               float* o_data,
                               int bs,
...
...
@@ -37,6 +816,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
                               const float* weights,
                               const float* bias,
                               const operators::ConvParam& param,
                               const operators::ActivationParam act_param,
                               ARMContext* ctx) {
  int threads = ctx->threads();
...
...
@@ -78,6 +858,31 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
  remain = remain > 0 ? remain : 0;
  int row_len = win_round * out_c_block;
  float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
  float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
  float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
  if (act_param.has_active) {
    switch (act_param.active_type) {
      case lite_api::ActivationType::kRelu:
        break;
      case lite_api::ActivationType::kRelu6:
        six_ptr[0] = act_param.Relu_clipped_coef;
        six_ptr[1] = act_param.Relu_clipped_coef;
        six_ptr[2] = act_param.Relu_clipped_coef;
        six_ptr[3] = act_param.Relu_clipped_coef;
        break;
      case lite_api::ActivationType::kLeakyRelu:
        scale_ptr[0] = act_param.Leaky_relu_alpha;
        scale_ptr[1] = act_param.Leaky_relu_alpha;
        scale_ptr[2] = act_param.Leaky_relu_alpha;
        scale_ptr[3] = act_param.Leaky_relu_alpha;
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param.active_type)
                   << " fuse not support";
    }
  }
  for (int n = 0; n < bs; ++n) {
    const float* din_batch = i_data + n * ic * size_in_channel;
    float* dout_batch = o_data + n * oc * size_out_channel;
...
...
@@ -147,6 +952,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
          outc21 = ptr_write;
          outc31 = ptr_write;
        }
        float* outl[] = {outc00,
                         outc10,
                         outc20,
...
...
@@ -156,361 +962,54 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outc21
,
outc31
,
reinterpret_cast
<
float
*>
(
bias_local
),
reinterpret_cast
<
float
*>
(
flag_relu
)};
reinterpret_cast
<
float
*>
(
relu_ptr
),
reinterpret_cast
<
float
*>
(
six_ptr
),
reinterpret_cast
<
float
*>
(
scale_ptr
)};
void
*
outl_ptr
=
reinterpret_cast
<
void
*>
(
outl
);
for
(
int
w
=
0
;
w
<
w_loop
;
++
w
)
{
bool
flag_mask
=
(
w
==
w_loop
-
1
)
&&
flag_remain
;
float
*
out0
=
pre_out
;
          // clang-format off
#ifdef __aarch64__
          asm volatile(
              "ldp  q0, q1, [%[inr0]], #32 \n"   /* load input r0*/
              "ldp  q6, q7, [%[inr1]], #32 \n"   /* load input r1*/
              "ldp  q2, q3, [%[inr0]], #32 \n"   /* load input r0*/
              "ldp  q8, q9, [%[inr1]], #32 \n"   /* load input r1*/
              "ldp  q4, q5, [%[inr0]] \n"        /* load input r0*/
              "ldp  q10, q11, [%[inr1]] \n"      /* load input r1*/
              /* r0, r1, mul w0, get out r0, r1 */
              "fmul v15.4s, %[w0].4s, v0.4s \n"  /* outr00 = w0 * r0, 0*/
              "fmul v16.4s, %[w0].4s, v1.4s \n"  /* outr01 = w0 * r0, 1*/
              "fmul v17.4s, %[w0].4s, v2.4s \n"  /* outr02 = w0 * r0, 2*/
              "fmul v18.4s, %[w0].4s, v3.4s \n"  /* outr03 = w0 * r0, 3*/
              "fmul v19.4s, %[w0].4s, v6.4s \n"  /* outr10 = w0 * r1, 0*/
              "fmul v20.4s, %[w0].4s, v7.4s \n"  /* outr11 = w0 * r1, 1*/
              "fmul v21.4s, %[w0].4s, v8.4s \n"  /* outr12 = w0 * r1, 2*/
              "fmul v22.4s, %[w0].4s, v9.4s \n"  /* outr13 = w0 * r1, 3*/
              /* r0, r1, mul w1, get out r0, r1 */
              "fmla v15.4s, %[w1].4s, v1.4s \n"  /* outr00 = w1 * r0[1]*/
              "ldp  q0, q1, [%[inr2]], #32 \n"   /* load input r2*/
              "fmla v16.4s, %[w1].4s, v2.4s \n"  /* outr01 = w1 * r0[2]*/
              "fmla v17.4s, %[w1].4s, v3.4s \n"  /* outr02 = w1 * r0[3]*/
              "fmla v18.4s, %[w1].4s, v4.4s \n"  /* outr03 = w1 * r0[4]*/
              "fmla v19.4s, %[w1].4s, v7.4s \n"  /* outr10 = w1 * r1[1]*/
              "fmla v20.4s, %[w1].4s, v8.4s \n"  /* outr11 = w1 * r1[2]*/
              "fmla v21.4s, %[w1].4s, v9.4s \n"  /* outr12 = w1 * r1[3]*/
              "fmla v22.4s, %[w1].4s, v10.4s \n" /* outr13 = w1 * r1[4]*/
              /* r0, r1, mul w2, get out r0, r1 */
              "fmla v15.4s, %[w2].4s, v2.4s \n"  /* outr00 = w2 * r0[2]*/
              "fmla v16.4s, %[w2].4s, v3.4s \n"  /* outr01 = w2 * r0[3]*/
              "ldp  q2, q3, [%[inr2]], #32 \n"   /* load input r2*/
              "fmla v17.4s, %[w2].4s, v4.4s \n"  /* outr02 = w2 * r0[4]*/
              "fmla v18.4s, %[w2].4s, v5.4s \n"  /* outr03 = w2 * r0[5]*/
              "ldp  q4, q5, [%[inr2]] \n"        /* load input r2*/
              "fmla v19.4s, %[w2].4s, v8.4s \n"  /* outr10 = w2 * r1[2]*/
              "fmla v20.4s, %[w2].4s, v9.4s \n"  /* outr11 = w2 * r1[3]*/
              "fmla v21.4s, %[w2].4s, v10.4s \n" /* outr12 = w2 * r1[4]*/
              "fmla v22.4s, %[w2].4s, v11.4s \n" /* outr13 = w2 * r1[5]*/
              /* r1, r2, mul w3, get out r0, r1 */
              "fmla v15.4s, %[w3].4s, v6.4s \n"  /* outr00 = w3 * r1[0]*/
              "fmla v16.4s, %[w3].4s, v7.4s \n"  /* outr01 = w3 * r1[1]*/
              "fmla v17.4s, %[w3].4s, v8.4s \n"  /* outr02 = w3 * r1[2]*/
              "fmla v18.4s, %[w3].4s, v9.4s \n"  /* outr03 = w3 * r1[3]*/
              "fmla v19.4s, %[w3].4s, v0.4s \n"  /* outr10 = w3 * r2[0]*/
              "fmla v20.4s, %[w3].4s, v1.4s \n"  /* outr11 = w3 * r2[1]*/
              "fmla v21.4s, %[w3].4s, v2.4s \n"  /* outr12 = w3 * r2[2]*/
              "fmla v22.4s, %[w3].4s, v3.4s \n"  /* outr13 = w3 * r2[3]*/
              /* r1, r2, mul w4, get out r0, r1 */
              "fmla v15.4s, %[w4].4s, v7.4s \n"  /* outr00 = w4 * r1[1]*/
              "ldp  q6, q7, [%[inr3]], #32 \n"   /* load input r3*/
              "fmla v16.4s, %[w4].4s, v8.4s \n"  /* outr01 = w4 * r1[2]*/
              "fmla v17.4s, %[w4].4s, v9.4s \n"  /* outr02 = w4 * r1[3]*/
              "fmla v18.4s, %[w4].4s, v10.4s \n" /* outr03 = w4 * r1[4]*/
              "ldp  x0, x1, [%[outl]] \n"
              "fmla v19.4s, %[w4].4s, v1.4s \n"  /* outr10 = w4 * r2[1]*/
              "fmla v20.4s, %[w4].4s, v2.4s \n"  /* outr11 = w4 * r2[2]*/
              "fmla v21.4s, %[w4].4s, v3.4s \n"  /* outr12 = w4 * r2[3]*/
              "fmla v22.4s, %[w4].4s, v4.4s \n"  /* outr13 = w4 * r2[4]*/
              /* r1, r2, mul w5, get out r0, r1 */
              "fmla v15.4s, %[w5].4s, v8.4s \n"  /* outr00 = w5 * r1[2]*/
              "fmla v16.4s, %[w5].4s, v9.4s \n"  /* outr01 = w5 * r1[3]*/
              "ldp  q8, q9, [%[inr3]], #32 \n"   /* load input r3*/
              "fmla v17.4s, %[w5].4s, v10.4s \n" /* outr02 = w5 * r1[4]*/
              "fmla v18.4s, %[w5].4s, v11.4s \n" /* outr03 = w5 * r1[5]*/
              "ldp  q10, q11, [%[inr3]] \n"      /* load input r3*/
              "fmla v19.4s, %[w5].4s, v2.4s \n"  /* outr10 = w5 * r2[2]*/
              "fmla v20.4s, %[w5].4s, v3.4s \n"  /* outr11 = w5 * r2[3]*/
              "fmla v21.4s, %[w5].4s, v4.4s \n"  /* outr12 = w5 * r2[4]*/
              "fmla v22.4s, %[w5].4s, v5.4s \n"  /* outr13 = w5 * r2[5]*/
              /* r2, r3, mul w6, get out r0, r1 */
              "fmla v15.4s, %[w6].4s, v0.4s \n"  /* outr00 = w6 * r2[0]*/
              "fmla v16.4s, %[w6].4s, v1.4s \n"  /* outr01 = w6 * r2[1]*/
              "fmla v17.4s, %[w6].4s, v2.4s \n"  /* outr02 = w6 * r2[2]*/
              "fmla v18.4s, %[w6].4s, v3.4s \n"  /* outr03 = w6 * r2[3]*/
              "ldp  x2, x3, [%[outl], #16] \n"
              "fmla v19.4s, %[w6].4s, v6.4s \n"  /* outr10 = w6 * r3[0]*/
              "fmla v20.4s, %[w6].4s, v7.4s \n"  /* outr11 = w6 * r3[1]*/
              "fmla v21.4s, %[w6].4s, v8.4s \n"  /* outr12 = w6 * r3[2]*/
              "fmla v22.4s, %[w6].4s, v9.4s \n"  /* outr13 = w6 * r3[3]*/
              /* r2, r3, mul w7, get out r0, r1 */
              "fmla v15.4s, %[w7].4s, v1.4s \n"  /* outr00 = w7 * r2[1]*/
              "fmla v16.4s, %[w7].4s, v2.4s \n"  /* outr01 = w7 * r2[2]*/
              "fmla v17.4s, %[w7].4s, v3.4s \n"  /* outr02 = w7 * r2[3]*/
              "fmla v18.4s, %[w7].4s, v4.4s \n"  /* outr03 = w7 * r2[4]*/
              "ldp  x4, x5, [%[outl], #32] \n"
              "fmla v19.4s, %[w7].4s, v7.4s \n"  /* outr10 = w7 * r3[1]*/
              "fmla v20.4s, %[w7].4s, v8.4s \n"  /* outr11 = w7 * r3[2]*/
              "fmla v21.4s, %[w7].4s, v9.4s \n"  /* outr12 = w7 * r3[3]*/
              "fmla v22.4s, %[w7].4s, v10.4s \n" /* outr13 = w7 * r3[4]*/
              /* r2, r3, mul w8, get out r0, r1 */
              "fmla v15.4s, %[w8].4s, v2.4s \n"  /* outr00 = w8 * r2[2]*/
              "fmla v16.4s, %[w8].4s, v3.4s \n"  /* outr01 = w8 * r2[3]*/
              "fmla v17.4s, %[w8].4s, v4.4s \n"  /* outr02 = w8 * r2[0]*/
              "fmla v18.4s, %[w8].4s, v5.4s \n"  /* outr03 = w8 * r2[1]*/
              "ldp  x6, x7, [%[outl], #48] \n"
              "fmla v19.4s, %[w8].4s, v8.4s \n"  /* outr10 = w8 * r3[2]*/
              "fmla v20.4s, %[w8].4s, v9.4s \n"  /* outr11 = w8 * r3[3]*/
              "fmla v21.4s, %[w8].4s, v10.4s \n" /* outr12 = w8 * r3[0]*/
              "fmla v22.4s, %[w8].4s, v11.4s \n" /* outr13 = w8 * r3[1]*/
              "fadd v15.4s, v15.4s, %[vbias].4s \n" /* add bias */
              "fadd v16.4s, v16.4s, %[vbias].4s \n" /* add bias */
              "fadd v17.4s, v17.4s, %[vbias].4s \n" /* add bias */
              "fadd v18.4s, v18.4s, %[vbias].4s \n" /* add bias */
              "fadd v19.4s, v19.4s, %[vbias].4s \n" /* add bias */
              "fadd v20.4s, v20.4s, %[vbias].4s \n" /* add bias */
              "fadd v21.4s, v21.4s, %[vbias].4s \n" /* add bias */
              "fadd v22.4s, v22.4s, %[vbias].4s \n" /* add bias */
              /* transpose */
              "trn1 v0.4s, v15.4s, v16.4s \n"   /* r0: a0a1c0c1*/
              "trn2 v1.4s, v15.4s, v16.4s \n"   /* r0: b0b1d0d1*/
              "trn1 v2.4s, v17.4s, v18.4s \n"   /* r0: a2a3c2c3*/
              "trn2 v3.4s, v17.4s, v18.4s \n"   /* r0: b2b3d2d3*/
              "trn1 v4.4s, v19.4s, v20.4s \n"   /* r1: a0a1c0c1*/
              "trn2 v5.4s, v19.4s, v20.4s \n"   /* r1: b0b1d0d1*/
              "trn1 v6.4s, v21.4s, v22.4s \n"   /* r1: a2a3c2c3*/
              "trn2 v7.4s, v21.4s, v22.4s \n"   /* r1: b2b3d2d3*/
              "trn1 v15.2d, v0.2d, v2.2d \n"    /* r0: a0a1a2a3*/
              "trn2 v19.2d, v0.2d, v2.2d \n"    /* r0: c0c1c2c3*/
              "trn1 v17.2d, v1.2d, v3.2d \n"    /* r0: b0b1b2b3*/
              "trn2 v21.2d, v1.2d, v3.2d \n"    /* r0: d0d1d2d3*/
              "trn1 v16.2d, v4.2d, v6.2d \n"    /* r1: a0a1a2a3*/
              "trn2 v20.2d, v4.2d, v6.2d \n"    /* r1: c0c1c2c3*/
              "trn1 v18.2d, v5.2d, v7.2d \n"    /* r1: b0b1b2b3*/
              "trn2 v22.2d, v5.2d, v7.2d \n"    /* r1: d0d1d2d3*/
              "cbz  %w[flag_relu], 0f \n"       /* skip relu*/
              "movi v0.4s, #0 \n"               /* for relu */
              "fmax v15.4s, v15.4s, v0.4s \n"
              "fmax v16.4s, v16.4s, v0.4s \n"
              "fmax v17.4s, v17.4s, v0.4s \n"
              "fmax v18.4s, v18.4s, v0.4s \n"
              "fmax v19.4s, v19.4s, v0.4s \n"
              "fmax v20.4s, v20.4s, v0.4s \n"
              "fmax v21.4s, v21.4s, v0.4s \n"
              "fmax v22.4s, v22.4s, v0.4s \n"
              "0: \n"
              "cbnz %w[flag_mask], 1f \n"
              "str  q15, [x0] \n" /* save outc00 */
              "str  q16, [x4] \n" /* save outc01 */
              "str  q17, [x1] \n" /* save outc10 */
              "str  q18, [x5] \n" /* save outc11 */
              "str  q19, [x2] \n" /* save outc20 */
              "str  q20, [x6] \n" /* save outc21 */
              "str  q21, [x3] \n" /* save outc30 */
              "str  q22, [x7] \n" /* save outc31 */
              "b    2f \n"
              "1: \n"
              "str  q15, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q17, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q19, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q21, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q16, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q18, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q20, [%[out]], #16 \n" /* save remain to pre_out */
              "str  q22, [%[out]], #16 \n" /* save remain to pre_out */
              "2: \n"
              : [inr0] "+r"(inr0), [inr1] "+r"(inr1),
                [inr2] "+r"(inr2), [inr3] "+r"(inr3),
                [out] "+r"(out0)
              : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), [w3] "w"(w3),
                [w4] "w"(w4), [w5] "w"(w5), [w6] "w"(w6), [w7] "w"(w7),
                [w8] "w"(w8), [vbias] "w"(vbias), [outl] "r"(outl_ptr),
                [flag_mask] "r"(flag_mask), [flag_relu] "r"(flag_relu)
              : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
                "v7", "v8", "v9", "v10", "v11", "v15", "v16", "v17", "v18",
                "v19", "v20", "v21", "v22", "x0", "x1", "x2", "x3", "x4",
                "x5", "x6", "x7");
          act_switch_3x3s1(inr0,
                           inr1,
                           inr2,
                           inr3,
                           out0,
                           weight_c,
                           flag_mask,
                           outl_ptr,
                           w0, w1, w2, w3, w4, w5, w6, w7, w8,
                           vbias,
                           act_param);
#else
          asm volatile(
              /* load weights */
              "vld1.32  {d10-d13}, [%[wc0]]!  @ load w0, w1, to q5, q6\n"
              "vld1.32  {d14-d15}, [%[wc0]]!  @ load w2, to q7\n"
              /* load r0, r1 */
              "vld1.32  {d0-d3}, [%[r0]]!     @ load r0, q0, q1\n"
              "vld1.32  {d4-d7}, [%[r0]]!     @ load r0, q2, q3\n"
              /* main loop */
              "0:                             @ main loop\n"
              /* mul r0 with w0, w1, w2, get out r0 */
              "vmul.f32 q8, q5, q0            @ w0 * inr00\n"
              "vmul.f32 q9, q5, q1            @ w0 * inr01\n"
              "vmul.f32 q10, q5, q2           @ w0 * inr02\n"
              "vmul.f32 q11, q5, q3           @ w0 * inr03\n"
              "vmla.f32 q8, q6, q1            @ w1 * inr01\n"
              "vld1.32  {d0-d3}, [%[r0]]      @ load r0, q0, q1\n"
              "vmla.f32 q9, q6, q2            @ w1 * inr02\n"
              "vmla.f32 q10, q6, q3           @ w1 * inr03\n"
              "vmla.f32 q11, q6, q0           @ w1 * inr04\n"
              "vmla.f32 q8, q7, q2            @ w2 * inr02\n"
              "vmla.f32 q9, q7, q3            @ w2 * inr03\n"
              "vld1.32  {d4-d7}, [%[r1]]!     @ load r0, q2, q3\n"
              "vmla.f32 q10, q7, q0           @ w2 * inr04\n"
              "vmla.f32 q11, q7, q1           @ w2 * inr05\n"
              "vld1.32  {d0-d3}, [%[r1]]!     @ load r0, q0, q1\n"
              "vld1.32  {d8-d9}, [%[wc0]]!    @ load w3 to q4\n"
              /* mul r1 with w0-w5, get out r0, r1 */
              "vmul.f32 q12, q5, q2           @ w0 * inr10\n"
              "vmul.f32 q13, q5, q3           @ w0 * inr11\n"
              "vmul.f32 q14, q5, q0           @ w0 * inr12\n"
              "vmul.f32 q15, q5, q1           @ w0 * inr13\n"
              "vld1.32  {d10-d11}, [%[wc0]]!  @ load w4 to q5\n"
              "vmla.f32 q8, q4, q2            @ w3 * inr10\n"
              "vmla.f32 q9, q4, q3            @ w3 * inr11\n"
              "vmla.f32 q10, q4, q0           @ w3 * inr12\n"
              "vmla.f32 q11, q4, q1           @ w3 * inr13\n"
              /* mul r1 with w1, w4, get out r1, r0 */
              "vmla.f32 q8, q5, q3            @ w4 * inr11\n"
              "vmla.f32 q12, q6, q3           @ w1 * inr11\n"
              "vld1.32  {d4-d7}, [%[r1]]      @ load r1, q2, q3\n"
              "vmla.f32 q9, q5, q0            @ w4 * inr12\n"
              "vmla.f32 q13, q6, q0           @ w1 * inr12\n"
              "vmla.f32 q10, q5, q1           @ w4 * inr13\n"
              "vmla.f32 q14, q6, q1           @ w1 * inr13\n"
              "vmla.f32 q11, q5, q2           @ w4 * inr14\n"
              "vmla.f32 q15, q6, q2           @ w1 * inr14\n"
              "vld1.32  {d12-d13}, [%[wc0]]!  @ load w5 to q6\n"
              /* mul r1 with w2, w5, get out r1, r0 */
              "vmla.f32 q12, q7, q0           @ w2 * inr12\n"
              "vmla.f32 q13, q7, q1           @ w2 * inr13\n"
              "vmla.f32 q8, q6, q0            @ w5 * inr12\n"
              "vmla.f32 q9, q6, q1            @ w5 * inr13\n"
              "vld1.32  {d0-d3}, [%[r2]]!     @ load r2, q0, q1\n"
              "vmla.f32 q14, q7, q2           @ w2 * inr14\n"
              "vmla.f32 q15, q7, q3           @ w2 * inr15\n"
              "vmla.f32 q10, q6, q2           @ w5 * inr14\n"
              "vmla.f32 q11, q6, q3           @ w5 * inr15\n"
              "vld1.32  {d4-d7}, [%[r2]]!     @ load r2, q0, q1\n"
              "vld1.32  {d14-d15}, [%[wc0]]!  @ load w6, to q7\n"
              /* mul r2 with w3-w8, get out r0, r1 */
              "vmla.f32 q12, q4, q0           @ w3 * inr20\n"
              "vmla.f32 q13, q4, q1           @ w3 * inr21\n"
              "vmla.f32 q14, q4, q2           @ w3 * inr22\n"
              "vmla.f32 q15, q4, q3           @ w3 * inr23\n"
              "vld1.32  {d8-d9}, [%[wc0]]!    @ load w7, to q4\n"
              "vmla.f32 q8, q7, q0            @ w6 * inr20\n"
              "vmla.f32 q9, q7, q1            @ w6 * inr21\n"
              "vmla.f32 q10, q7, q2           @ w6 * inr22\n"
              "vmla.f32 q11, q7, q3           @ w6 * inr23\n"
              /* mul r2 with w4, w7, get out r1, r0 */
              "vmla.f32 q8, q4, q1            @ w7 * inr21\n"
              "vmla.f32 q12, q5, q1           @ w4 * inr21\n"
              "vld1.32  {d0-d3}, [%[r2]]      @ load r2, q0, q1\n"
              "vmla.f32 q9, q4, q2            @ w7 * inr22\n"
              "vmla.f32 q13, q5, q2           @ w4 * inr22\n"
              "vmla.f32 q10, q4, q3           @ w7 * inr23\n"
              "vmla.f32 q14, q5, q3           @ w4 * inr23\n"
              "vmla.f32 q11, q4, q0           @ w7 * inr24\n"
              "vmla.f32 q15, q5, q0           @ w4 * inr24\n"
              "vld1.32  {d10-d11}, [%[wc0]]!  @ load w8 to q5\n"
              /* mul r1 with w5, w8, get out r1, r0 */
              "vmla.f32 q12, q6, q2           @ w5 * inr22\n"
              "vmla.f32 q13, q6, q3           @ w5 * inr23\n"
              "vmla.f32 q8, q5, q2            @ w8 * inr22\n"
              "vmla.f32 q9, q5, q3            @ w8 * inr23\n"
              "vld1.32  {d4-d7}, [%[r3]]!     @ load r3, q2, q3\n"
              "ldr      r4, [%[outl], #32]    @ load bias addr to r4\n"
              "vmla.f32 q14, q6, q0           @ w5 * inr24\n"
              "vmla.f32 q15, q6, q1           @ w5 * inr25\n"
              "vmla.f32 q10, q5, q0           @ w8 * inr24\n"
              "vmla.f32 q11, q5, q1           @ w8 * inr25\n"
              "vld1.32  {d0-d3}, [%[r3]]!     @ load r3, q0, q1\n"
              "sub %[wc0], %[wc0], #144       @ wc0 - 144 to start address\n"
              /* mul r3 with w6, w7, w8, get out r1 */
              "vmla.f32 q12, q7, q2           @ w6 * inr30\n"
              "vmla.f32 q13, q7, q3           @ w6 * inr31\n"
              "vmla.f32 q14, q7, q0           @ w6 * inr32\n"
              "vmla.f32 q15, q7, q1           @ w6 * inr33\n"
              "vmla.f32 q12, q4, q3           @ w7 * inr31\n"
              "vld1.32  {d4-d7}, [%[r3]]      @ load r3, q2, q3\n"
              "vld1.32  {d12-d13}, [r4]       @ load bias\n"
              "vmla.f32 q13, q4, q0           @ w7 * inr32\n"
              "vmla.f32 q14, q4, q1           @ w7 * inr33\n"
              "vmla.f32 q15, q4, q2           @ w7 * inr34\n"
              "ldr      r0, [%[outl]]         @ load outc00 to r0\n"
              "vmla.f32 q12, q5, q0           @ w8 * inr32\n"
              "vmla.f32 q13, q5, q1           @ w8 * inr33\n"
              "ldr      r5, [%[outl], #36]    @ load flag_relu to r5\n"
              "vmla.f32 q14, q5, q2           @ w8 * inr34\n"
              "vmla.f32 q15, q5, q3           @ w8 * inr35\n"
              "ldr      r1, [%[outl], #4]     @ load outc10 to r1\n"
              "vadd.f32 q8, q8, q6            @ r00 add bias\n"
              "vadd.f32 q9, q9, q6            @ r01 add bias\n"
              "vadd.f32 q10, q10, q6          @ r02 add bias\n"
              "vadd.f32 q11, q11, q6          @ r03 add bias\n"
              "ldr      r2, [%[outl], #8]     @ load outc20 to r2\n"
              "vadd.f32 q12, q12, q6          @ r10 add bias\n"
              "vadd.f32 q13, q13, q6          @ r11 add bias\n"
              "vadd.f32 q14, q14, q6          @ r12 add bias\n"
              "vadd.f32 q15, q15, q6          @ r13 add bias\n"
              "ldr      r3, [%[outl], #12]    @ load outc30 to r3\n"
              "vmov.u32 q7, #0                @ mov zero to q7\n"
              "cmp      r5, #0                @ cmp flag relu\n"
              "beq      1f                    @ skip relu\n"
              "vmax.f32 q8, q8, q7            @ r00 relu\n"
              "vmax.f32 q9, q9, q7            @ r01 relu\n"
              "vmax.f32 q10, q10, q7          @ r02 relu\n"
              "vmax.f32 q11, q11, q7          @ r03 relu\n"
              "vmax.f32 q12, q12, q7          @ r10 relu\n"
              "vmax.f32 q13, q13, q7          @ r11 relu\n"
              "vmax.f32 q14, q14, q7          @ r12 relu\n"
              "vmax.f32 q15, q15, q7          @ r13 relu\n"
              "1:\n"
              "ldr      r4, [%[outl], #16]    @ load outc01 to r4\n"
              "vtrn.32  q8, q9                @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n"
              "vtrn.32  q10, q11              @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n"
              "vtrn.32  q12, q13              @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n"
              "vtrn.32  q14, q15              @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n"
              "ldr      r5, [%[outl], #20]    @ load outc11 to r5\n"
              "vswp     d17, d20              @ r0: q8 : a0a1a2a3, q10: c0c1c2c3\n"
              "vswp     d19, d22              @ r0: q9 : b0b1b2b3, q11: d0d1d2d3\n"
              "vswp     d25, d28              @ r1: q12: a0a1a2a3, q14: c0c1c2c3\n"
              "vswp     d27, d30              @ r1: q13: b0b1b2b3, q15: d0d1d2d3\n"
              "cmp      %[flag_mask], #0      @ cmp flag mask\n"
              "bne      2f\n"
              "vst1.32  {d16-d17}, [r0]       @ save outc00\n"
              "vst1.32  {d18-d19}, [r1]       @ save outc10\n"
              "vst1.32  {d20-d21}, [r2]       @ save outc20\n"
              "vst1.32  {d22-d23}, [r3]       @ save outc30\n"
              "vst1.32  {d24-d25}, [r4]       @ save outc01\n"
              "vst1.32  {d26-d27}, [r5]       @ save outc11\n"
              "ldr      r0, [%[outl], #24]    @ load outc21 to r0\n"
              "ldr      r1, [%[outl], #28]    @ load outc31 to r1\n"
              "vst1.32  {d28-d29}, [r0]       @ save outc21\n"
              "vst1.32  {d30-d31}, [r1]       @ save outc31\n"
              "b        3f                    @ branch end\n"
              "2:\n"
              "vst1.32  {d16-d17}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d18-d19}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d20-d21}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d22-d23}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d24-d25}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d26-d27}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d28-d29}, [%[out0]]! @ save remain to pre_out\n"
              "vst1.32  {d30-d31}, [%[out0]]! @ save remain to pre_out\n"
              "3:\n"
              : [r0] "+r"(inr0), [r1] "+r"(inr1), [r2] "+r"(inr2),
                [r3] "+r"(inr3), [out0] "+r"(out0), [wc0] "+r"(weight_c)
              : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
                "r0", "r1", "r2", "r3", "r4", "r5");
#endif  // __arch64__
          // clang-format on
          act_switch_3x3s1(inr0,
                           inr1,
                           inr2,
                           inr3,
                           out0,
                           weight_c,
                           flag_mask,
                           outl_ptr,
                           vbias, vbias, vbias, vbias, vbias,
                           vbias, vbias, vbias, vbias, vbias,
                           act_param);
#endif
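Two details worth calling out in this hunk. First, the activation epilogue is no longer baked into each asm body; `act_switch_3x3s1` receives the row pointers plus `act_param` and selects the epilogue (on armv7 the nine weight-register parameters are filled with `vbias` placeholders, since the weights stay in memory there). Second, the remaining asm addresses the `outl` pointer table by byte offset, and the new entries line up with the old ones, as this illustrative compile-time check shows (indices follow the `outl[]` initializer above):

  // 8 output pointers, then bias, relu_ptr, six_ptr, scale_ptr.
  constexpr int kBiasEntry = 8;
  constexpr int kReluEntry = 9;  // relu_ptr now sits where flag_relu was
  static_assert(4 * kBiasEntry == 32, "armv7: 'ldr r4, [%[outl], #32]'");
  static_assert(4 * kReluEntry == 36, "armv7: 'ldr r5, [%[outl], #36]'");
  // aarch64 pointers are 8 bytes wide, hence the 'ldp ..., [%[outl]]',
  // '#16', '#32', '#48' pairs in the 64-bit kernel.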
          outl[0] += 4;
          outl[1] += 4;
          outl[2] += 4;
...
@@ -519,6 +1018,10 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
          outl[5] += 4;
          outl[6] += 4;
          outl[7] += 4;
          inr0 += 16;
          inr1 += 16;
          inr2 += 16;
          inr3 += 16;
          if (flag_mask) {
            memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
            memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
...
lite/backends/arm/math/conv3x3s2_direct_fp32.cc  View file @ 3455ab0a
...
@@ -75,6 +75,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
//! prepack input to tmp buffer
//! write output to tmp buffer
  auto paddings = *param.paddings;
  auto act_param = param.activation_param;
  const int threads = ctx->threads();
  int l2_size = ctx->llc_size() / sizeof(float);
  const int pad_w = paddings[2];
...
@@ -510,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
                              oh,
                              ow,
                              flag_relu,
                              ptr_write);
                              ptr_write,
                              &act_param);
    }
#pragma omp parallel for num_threads(threads)
...
@@ -839,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
                              oh,
                              ow,
                              flag_relu,
                              ptr_write);
                              ptr_write,
                              &act_param);
      }
    }
  }
...
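Same pattern as the s1 direct kernel: `conv_3x3s2_direct_fp32` now keeps an `act_param` copy and hands its address to the output writer, while `flag_relu` stays for the plain-relu fast path. A sketch of the resulting call shape (the arguments before `flag_relu` are placeholders, not this file's actual locals):

  // Hypothetical call-site sketch; only the trailing &act_param is new.
  write_to_output_c4_fp32(pre_out, dout_ptr,
                          c, c + 4, h, h + h_kernel, 0, wout_round,
                          oc, oh, ow, flag_relu, ptr_write, &act_param);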
lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc  View file @ 3455ab0a
...
@@ -205,14 +205,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
                                                      \
  "ext v10.16b, %[vzero].16b, v9.16b, #12     \n"     \
  "fadd v16.4s, v16.4s, v11.4s                \n"     \
  "fadd v16.4s, v16.4s, v12.4s                \n"
  "fadd v16.4s, v16.4s, v12.4s                \n" /* r4 */ \
  "fmla v13.4s, v8.4s, %[w2].s[1]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[2]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[0]            \n"
#define LEFT_RESULT_S2                                \
  /* r4 */                                            \
  "fmla v13.4s, v8.4s, %[w2].s[1]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[2]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[0]            \n"     \
                                                      \
  "st1 {v16.4s}, [%[outptr0]], #16            \n"     \
                                                      \
  "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32       \n"     \
...
                                                      \
  "blt 1f                                     \n"
#define MID_COMPUTE_S2                                \
  "2:                                         \n" /* r0 */ \
  "fmul v11.4s, v0.4s, %[w0].s[0]             \n"     \
  "fmul v12.4s, v1.4s, %[w0].s[1]             \n"     \
  "fmla v16.4s, v10.4s, %[w0].s[2]            \n"     \
                                                      \
  "ext v10.16b, v2.16b, v18.16b, #4           \n"     \
  "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32       \n" /* r1 */ \
  "fmla v11.4s, v2.4s, %[w1].s[0]             \n"     \
  "fmla v12.4s, v3.4s, %[w1].s[1]             \n"     \
  "fmla v16.4s, v10.4s, %[w1].s[2]            \n"     \
                                                      \
  "ext v10.16b, v4.16b, v19.16b, #4           \n"     \
                                                      \
  "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32       \n" /* r2 */ \
  "fmul v13.4s, v4.4s, %[w0].s[0]             \n"     \
  "fmla v11.4s, v4.4s, %[w2].s[0]             \n"     \
                                                      \
  "fmul v14.4s, v5.4s, %[w0].s[1]             \n"     \
  "fmla v12.4s, v5.4s, %[w2].s[1]             \n"     \
                                                      \
  "fmla v17.4s, v10.4s, %[w0].s[2]            \n"     \
  "fmla v16.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "ext v10.16b, v6.16b, v20.16b, #4           \n"     \
                                                      \
  "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32       \n" /* r3 */ \
  "fmla v13.4s, v6.4s, %[w1].s[0]             \n"     \
  "fmla v14.4s, v7.4s, %[w1].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w1].s[2]            \n"     \
                                                      \
  "ext v10.16b, v8.16b, v21.16b, #4           \n"     \
                                                      \
  "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32       \n"     \
                                                      \
  "fadd v16.4s, v16.4s, v11.4s                \n"     \
  "fadd v16.4s, v16.4s, v12.4s                \n"
#define MID_COMPUTE_S2                                \
  "2:                                         \n" /* r0 */ \
  "fmul v11.4s, v0.4s, %[w0].s[0]             \n"     \
  "fmul v12.4s, v1.4s, %[w0].s[1]             \n"     \
  "fmla v16.4s, v10.4s, %[w0].s[2]            \n"     \
                                                      \
  "ext v10.16b, v2.16b, v18.16b, #4           \n"     \
  "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32       \n" /* r1 */ \
  "fmla v11.4s, v2.4s, %[w1].s[0]             \n"     \
  "fmla v12.4s, v3.4s, %[w1].s[1]             \n"     \
  "fmla v16.4s, v10.4s, %[w1].s[2]            \n"     \
                                                      \
  "ext v10.16b, v4.16b, v19.16b, #4           \n"     \
                                                      \
  "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32       \n" /* r2 */ \
  "fmul v13.4s, v4.4s, %[w0].s[0]             \n"     \
  "fmla v11.4s, v4.4s, %[w2].s[0]             \n"     \
                                                      \
  "fmul v14.4s, v5.4s, %[w0].s[1]             \n"     \
  "fmla v12.4s, v5.4s, %[w2].s[1]             \n"     \
                                                      \
  "fmla v17.4s, v10.4s, %[w0].s[2]            \n"     \
  "fmla v16.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "ext v10.16b, v6.16b, v20.16b, #4           \n"     \
                                                      \
  "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32       \n" /* r3 */ \
  "fmla v13.4s, v6.4s, %[w1].s[0]             \n"     \
  "fmla v14.4s, v7.4s, %[w1].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w1].s[2]            \n"     \
                                                      \
  "ext v10.16b, v8.16b, v21.16b, #4           \n"     \
                                                      \
  "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32       \n"     \
                                                      \
  "fadd v16.4s, v16.4s, v11.4s                \n"     \
  "fadd v16.4s, v16.4s, v12.4s                \n" /* r4 */ \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32       \n"     \
  "ld1 {v15.4s}, [%[inptr0]]                  \n"     \
  "ld1 {v18.4s}, [%[inptr1]]                  \n"
#define MID_RESULT_S2                                 \
  /* r4 */                                            \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32       \n"     \
  "ld1 {v15.4s}, [%[inptr0]]                  \n"     \
  "ld1 {v18.4s}, [%[inptr1]]                  \n"     \
  "st1 {v16.4s}, [%[outptr0]], #16            \n"     \
                                                      \
  "fadd v17.4s, v17.4s, v13.4s                \n"     \
...
@@ -360,14 +357,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
                                                      \
  "fadd v16.4s, v16.4s, v11.4s                \n"     \
  "fadd v16.4s, v16.4s, v12.4s                \n"     \
  "ld1 {v1.4s}, [%[outptr1]]                  \n"
  "ld1 {v1.4s}, [%[outptr1]]                  \n" /* r4 */ \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n"
#define RIGHT_RESULT_S2                               \
  /* r4 */                                            \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "bif v16.16b, v0.16b, %[wmask].16b          \n"     \
                                                      \
  "fadd v17.4s, v17.4s, v13.4s                \n"     \
...
@@ -382,11 +377,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "4:                                         \n"
#define LEFT_RESULT_S2_RELU                           \
  /* r4 */                                            \
  "fmla v13.4s, v8.4s, %[w2].s[1]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[2]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[0]            \n"     \
                                                      \
  "fmax v16.4s, v16.4s, %[vzero].4s           \n"     \
                                                      \
  "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32       \n"     \
...
@@ -424,14 +414,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "blt 1f                                     \n"
#define MID_RESULT_S2_RELU                            \
  /* r4 */                                            \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32       \n"     \
  "ld1 {v15.4s}, [%[inptr0]]                  \n"     \
  "ld1 {v18.4s}, [%[inptr1]]                  \n"     \
  "fmax v16.4s, v16.4s, %[vzero].4s           \n" /* relu */ \
                                                      \
  "fadd v17.4s, v17.4s, v13.4s                \n"     \
...
@@ -457,11 +439,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "bne 2b                                     \n"
#define RIGHT_RESULT_S2_RELU                          \
  /* r4 */                                            \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n"     \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n"     \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n"     \
                                                      \
  "fmax v16.4s, v16.4s, %[vzero].4s           \n" /* relu */ \
                                                      \
  "fadd v17.4s, v17.4s, v13.4s                \n"     \
...
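The point of these s2 hunks: the trailing `/* r4 */` multiply-accumulate rows move out of every `*_RESULT_S2*` macro into the matching `*_COMPUTE_S2` macro, so the plain and `_RELU` result variants stop carrying duplicate copies. The kernels rely on adjacent string literals merging when the macros are pasted together at the asm() call site, as in this minimal demo (names are illustrative, not the real macros):

  // Adjacent C string literals concatenate, so shared work lives in one
  // macro and each activation flavour appends only its epilogue.
  #define DEMO_COMPUTE_S2 "mul-acc rows, incl. r4\n"
  #define DEMO_RESULT_S2 "store\n"
  #define DEMO_RESULT_S2_RELU "relu\n" "store\n"
  static const char* kPlainBody = DEMO_COMPUTE_S2 DEMO_RESULT_S2;
  static const char* kReluBody = DEMO_COMPUTE_S2 DEMO_RESULT_S2_RELU;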
lite/backends/arm/math/conv_block_utils.h  View file @ 3455ab0a
...
@@ -20,6 +20,7 @@
#include "lite/backends/arm/math/sgemm.h"
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
...
@@ -28,6 +29,7 @@ namespace arm {
namespace math {
#define LITEMAX(a, b) ((a) > (b) ? (a) : (b))
#define LITEMIN(a, b) ((a) < (b) ? (a) : (b))
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
template <PrecisionType Ptype>
...
@@ -589,7 +591,238 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
      }
    }
}
// clang-format off
#ifdef __aarch64__
#define NCHWC1_TRANS_FP32_COMPUTE                                   \
  "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
  "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
  "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
  "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
  "movi v20.4s, #0           \n" /* for relu */                     \
  "1:                        \n" /* main loop*/
#define NCHWC1_TRANS_FP32_RELU                    \
  "fmax v0.4s, v0.4s, v20.4s \n" /*relu*/         \
  "fmax v1.4s, v1.4s, v20.4s \n" /*relu*/         \
  "fmax v2.4s, v2.4s, v20.4s \n" /*relu*/         \
  "fmax v3.4s, v3.4s, v20.4s \n" /*relu*/
#define NCHWC1_TRANS_FP32_RELU6                   \
  "fmin v0.4s, v0.4s, %[six].4s \n" /* relu6 */   \
  "fmin v1.4s, v1.4s, %[six].4s \n" /* relu6 */   \
  "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */   \
  "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC1_TRANS_FP32_LEAKY_RELU                  \
  "cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */      \
  "cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */      \
  "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */      \
  "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */      \
  "fmul v8.4s, v0.4s, %[scale].4s  \n" /* mul */      \
  "fmul v9.4s, v1.4s, %[scale].4s  \n" /* mul */      \
  "fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */      \
  "fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */      \
  "bif v0.16b, v8.16b, v4.16b  \n" /* choose*/        \
  "bif v1.16b, v9.16b, v5.16b  \n" /* choose*/        \
  "bif v2.16b, v10.16b, v6.16b \n" /* choose*/        \
  "bif v3.16b, v11.16b, v7.16b \n" /* choose*/
#define NCHWC1_TRANS_FP32_STORE                                     \
  "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/                 \
                                                                    \
  "str q0, [%[doutc0r0]], #16 \n" /* store c0r0*/                   \
  "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/                   \
                                                                    \
  "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
  "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
                                                                    \
  "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/                   \
  "str q3, [%[doutc0r0]], #16 \n" /* store c2r0*/                   \
                                                                    \
  "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
  "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/   \
                                                                    \
  "bne 1b \n" /* jump to main loop*/
#else
#define NCHWC1_TRANS_FP32_COMPUTE                           \
  "vld1.32 {d0-d3}, [%[ptr_din]]!  @ load data, c0r0 \n"    \
  "vld1.32 {d4-d7}, [%[ptr_din]]!  @ load data, c0r0 \n"    \
  "vmov.u32 q15, #0                @ dump zero\n"           \
  "1:                              @ main loop\n"
#define NCHWC1_TRANS_FP32_RELU             \
  "vmax.f32 q0, q0, q15   @ relu\n"        \
  "vmax.f32 q1, q1, q15   @ relu\n"        \
  "vmax.f32 q2, q2, q15   @ relu\n"        \
  "vmax.f32 q3, q3, q15   @ relu\n"
#define NCHWC1_TRANS_FP32_RELU6              \
  "vmin.f32 q0, q0, %q[six]  @ relu6 \n"     \
  "vmin.f32 q1, q1, %q[six]  @ relu6 \n"     \
  "vmin.f32 q2, q2, %q[six]  @ relu6 \n"     \
  "vmin.f32 q3, q3, %q[six]  @ relu6 \n"
#define NCHWC1_TRANS_FP32_LEAKY_RELU         \
  "vcge.f32 q5, q0, q15      @ q0 > 0 \n"    \
  "vcge.f32 q6, q1, q15      @ q0 > 0 \n"    \
  "vcge.f32 q7, q2, q15      @ q0 > 0 \n"    \
  "vcge.f32 q8, q3, q15      @ q0 > 0 \n"    \
  "vmul.f32 q9, q0, %q[scale]  \n"           \
  "vmul.f32 q10, q1, %q[scale] \n"           \
  "vmul.f32 q11, q2, %q[scale] \n"           \
  "vmul.f32 q12, q3, %q[scale] \n"           \
  "vbif q0, q9, q5           @ choose \n"    \
  "vbif q1, q10, q6          @ choose \n"    \
  "vbif q2, q11, q7          @ choose \n"    \
  "vbif q3, q12, q8          @ choose \n"
#define NCHWC1_TRANS_FP32_STORE                             \
  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n"       \
  "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n"      \
  "subs %[cnt], %[cnt], #1         @ loop count - 1\n"      \
                                                            \
  "vld1.32 {d0-d3}, [%[ptr_din]]!  @ load data \n"          \
  "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n"       \
  "vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result, \n"      \
                                                            \
  "vld1.32 {d4-d7}, [%[ptr_din]]!  @ load data \n"          \
                                                            \
  "bne 1b                          @ jump to main loop\n"
#endif
// clang-format on
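For readers who prefer intrinsics, the leaky-relu macro above is the standard in-register select idiom. Roughly, for one q-register (a readability sketch only; the kernel stays in inline asm, and the asm performs the compare with `cmhs` on the float bit patterns, annotated there as `vcgeq_u32`):

  #include <arm_neon.h>

  // One-register sketch of NCHWC1_TRANS_FP32_LEAKY_RELU.
  static inline float32x4_t leaky_relu_q(float32x4_t v, float32x4_t scale) {
    uint32x4_t ge = vcgeq_f32(v, vdupq_n_f32(0.f));  // lane mask: v >= 0
    float32x4_t scaled = vmulq_f32(v, scale);        // "fmul": v * alpha
    return vbslq_f32(ge, v, scaled);                 // "bif": keep v if mask
  }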
inline void act_switch_c1_fp32(const float* din_ptr,
                               float* doutc0_ptr,
                               int cnt_loop,
                               const operators::ActivationParam* act_param) {
  if (act_param != nullptr && act_param->has_active) {
    float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
    float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
    switch (act_param->active_type) {
      case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
                         NCHWC1_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     :
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v20");
#else
        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
                         NCHWC1_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     :
                     : "q0", "q1", "q2", "q3", "q15");
#endif
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
#ifdef __aarch64__
        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
                         NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     : [six] "w"(six)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v20");
#else
        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
                         NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     : [six] "w"(six)
                     : "q0", "q1", "q2", "q3", "q15");
#endif
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU
                         NCHWC1_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     : [scale] "w"(scale)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v20");
#else
        asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU
                         NCHWC1_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     : [scale] "w"(scale)
                     : "q0", "q1", "q2", "q3", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q15");
#endif
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param->active_type)
                   << " fuse not support";
    }
  } else {
#ifdef __aarch64__
    asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [cnt] "+r"(cnt_loop),
                   [ptr_din] "+r"(din_ptr)
                 :
                 : "v0", "v1", "v2", "v3", "v20");
#else
    asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [ptr_din] "+r"(din_ptr),
                   [cnt] "+r"(cnt_loop)
                 :
                 : "q0", "q1", "q2", "q3", "q15");
#endif
  }
}
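For reference, the caller side is just an `operators::ActivationParam`; only fields this diff itself reads are shown (a usage sketch, not code from the commit):

  operators::ActivationParam act_param;
  act_param.has_active = true;
  act_param.active_type = lite_api::ActivationType::kRelu6;
  act_param.Relu_clipped_coef = 6.f;  // broadcast into %[six]
  // For leaky relu instead:
  //   act_param.active_type = lite_api::ActivationType::kLeakyRelu;
  //   act_param.Leaky_relu_alpha = 0.1f;  // broadcast into %[scale]
  act_switch_c1_fp32(din_hei_ptr, doutc0_ptr, cnt_loop, &act_param);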
/* write result in outputs
 * input din: [n, c, h, w], output dout: [n, c, h, w]
 */
...
@@ -605,13 +838,14 @@ inline bool write_to_output_c1_fp32(const float* din,
                                    int height,
                                    int width,
                                    bool flag_relu,
                                    float* trash_ptr) {
                                    float* trash_ptr,
                                    operators::ActivationParam* act_param) {
  if (cs > channel) {
    return true;
  }
  const int c1 = 1;
  const int w4 = 4;
  const int w4 = 16;
  int size_c_out = width * height;
...
@@ -623,98 +857,53 @@ inline bool write_to_output_c1_fp32(const float* din,
  int w_round = we - ws;
  int cnt = (width - ws) / w4;
  int remain = (width - ws) % w4;
  for (int i = 0; i < size_h; i++) {
    int size_w = i * width;
    float* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
    const float* din_hei_ptr = ptr_din + i * w_round * c1;
    if (cnt > 0) {
      int cnt_loop = cnt;
      if (flag_relu) {
#ifdef __aarch64__
        asm volatile(
            "ldr q0, [%[ptr_din]], #16  \n" /* load data, c0r0, c0r1, c0r2,
                                               c0r3 */
            "movi v20.4s, #0            \n" /* for relu */
            "1:                         \n" /* main loop*/
            "fmax v1.4s, v0.4s, v20.4s  \n" /*relu*/
            "ldr q0, [%[ptr_din]], #16  \n" /* load data, c0r0, c0r1, c0r2,
                                               c0r3 */
            "subs %w[cnt], %w[cnt], #1  \n" /* loop count -1*/
            "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/
            "bne 1b                     \n" /* jump to main loop*/
            : [doutc0r0] "+r"(doutc0_ptr),
              [cnt] "+r"(cnt_loop),
              [ptr_din] "+r"(din_hei_ptr)
            :
            : "v0", "v1", "v20");
#else
        asm volatile(
            "vld1.32 {d0-d1}, [%[ptr_din]]!  @ load data, c0r0, "
            "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n"
            "vmov.u32 q15, #0                @ dump zero\n"
            "1:                              @ main loop\n"
            "vmax.f32 q1, q0, q15            @ relu\n"
            "vld1.32 {d0-d1}, [%[ptr_din]]!  @ load data\n"
            "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, add "
            "pointer\n"
            "subs %[cnt], %[cnt], #1         @ loop count - 1\n"
            "bne 1b                          @ jump to main loop\n"
            : [doutc0r0] "+r"(doutc0_ptr),
              [ptr_din] "+r"(din_hei_ptr),
              [cnt] "+r"(cnt_loop)
            :
            : "q0", "q1", "q15");
#endif
      } else {
#ifdef __aarch64__
        asm volatile(
            "ldr q0, [%[ptr_din]], #16  \n" /* load data, c0r0, c0r1, c0r2,
                                               c0r3 */
            "1:                         \n" /* main loop*/
            "str q0, [%[doutc0r0]], #16 \n" /* store c2r0*/
            "subs %w[cnt], %w[cnt], #1  \n" /* loop count -1*/
            "ldr q0, [%[ptr_din]], #16  \n" /* load data, c0r0, c0r1, c0r2,
                                               c0r3 */
            "bne 1b                     \n" /* jump to main loop*/
            : [doutc0r0] "+r"(doutc0_ptr),
              [cnt] "+r"(cnt_loop),
              [ptr_din] "+r"(din_hei_ptr)
            :
            : "v0");
#else
        asm volatile(
            "vld1.32 {d0-d1}, [%[ptr_din]]!  @ load data, c0r0, "
            "c0r1, c0r2, c0r3\n"
            "1:                              @ main loop\n"
            "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add "
            "pointer\n"
            "subs %[cnt], %[cnt], #1         @ loop count - 1\n"
            "vld1.32 {d0-d1}, [%[ptr_din]]!  @ load data\n"
            "bne 1b                          @ jump to main loop\n"
            : [doutc0r0] "+r"(doutc0_ptr),
              [ptr_din] "+r"(din_hei_ptr),
              [cnt] "+r"(cnt_loop)
            :
            : "q0");
#endif
      }
      act_switch_c1_fp32(din_hei_ptr, doutc0_ptr, cnt_loop, act_param);
    }
    if (we > width) {
      if (remain > 0) {
        int offset = i * w_round * c1 + c1 * w4 * cnt;
        din_hei_ptr = ptr_din + offset;
        int j = we - w4;
        if (flag_relu) {
          for (; j < width; ++j) {
            *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
            din_hei_ptr++;
        doutc0_ptr += w4 * cnt;
        int j = w4 * cnt;
        if (act_param != nullptr && act_param->has_active) {
          float six = act_param->Relu_clipped_coef;
          float scale = act_param->Leaky_relu_alpha;
          switch (act_param->active_type) {
            case lite_api::ActivationType::kRelu:
              for (; j < width; ++j) {
                *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
                din_hei_ptr++;
              }
              break;
            case lite_api::ActivationType::kRelu6:
              /* 0 <= din <= 6 */
              for (; j < width; ++j) {
                float tmp = LITEMAX(din_hei_ptr[0], 0.f);
                *(doutc0_ptr++) = LITEMIN(tmp, six);
                din_hei_ptr++;
              }
              break;
            case lite_api::ActivationType::kLeakyRelu:
              /*din = din >= 0 ? din : din * scale*/
              for (; j < width; ++j) {
                if (din_hei_ptr[0] >= 0) {
                  *(doutc0_ptr++) = din_hei_ptr[0];
                } else {
                  *(doutc0_ptr++) = din_hei_ptr[0] * scale;
                }
                din_hei_ptr++;
              }
              break;
            default:
              LOG(FATAL) << "this act_type: "
                         << static_cast<int>(act_param->active_type)
                         << " fuse not support";
          }
        } else {
          for (; j < width; ++j) {
...
@@ -725,6 +914,7 @@ inline bool write_to_output_c1_fp32(const float* din,
  }
  return true;
}
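With `w4` raised from 4 to 16, the c1 vector loop now consumes four q-registers per iteration and the scalar switch above only mops up the sub-16 tail. A worked example of the split (a sketch with made-up sizes):

  // c1 writer tiling after the w4 = 16 change.
  int width = 37, ws = 0;
  const int w4 = 16;
  int cnt = (width - ws) / w4;     // 2 vector iterations cover columns 0..31
  int remain = (width - ws) % w4;  // 5 columns left over
  int j = w4 * cnt;                // scalar tail starts at column 32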
// clang-format off
#ifdef __aarch64__
#define NCHWC2_TRANS_FP32_COMPUTE                                    \
  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/ \
...
@@ -740,6 +930,18 @@ inline bool write_to_output_c1_fp32(const float* din,
  "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/                            \
  "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/
#define NCHWC2_TRANS_FP32_RELU6                   \
  "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */   \
  "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC2_TRANS_FP32_LEAKY_RELU                  \
  "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */      \
  "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */      \
  "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */       \
  "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */       \
  "bif v2.16b, v4.16b, v6.16b \n" /* choose*/         \
  "bif v3.16b, v5.16b, v7.16b \n" /* choose*/
#define NCHWC2_TRANS_FP32_STORE                       \
  "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/   \
                                                      \
...
@@ -749,8 +951,7 @@ inline bool write_to_output_c1_fp32(const float* din,
  "bne 1b \n" /* jump to main loop*/
#else
#define NCHWC2_TRANS_FP32_COMPUTE                                      \
  "vld1.32 {d0-d3}, [%[ptr_din]]!  @ load data, c0r0, "                \
  "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n"                       \
  "vld1.32 {d0-d3}, [%[ptr_din]]!  @ load data, c0r0, c1r0 \n"         \
  "vmov.u32 q15, #0                @ dump zero\n"                      \
  "1:                              @ main loop\n"                      \
  "vtrn.32 d0, d1                  @ trans data:c0r0, c0r1, "          \
...
@@ -764,11 +965,21 @@ inline bool write_to_output_c1_fp32(const float* din,
  "vmax.f32 q0, q0, q15            @ relu\n"                           \
  "vmax.f32 q1, q1, q15            @ relu\n"
#define NCHWC2_TRANS_FP32_RELU6              \
  "vmin.f32 q0, q0, %q[six]  @ relu6 \n"     \
  "vmin.f32 q1, q1, %q[six]  @ relu6 \n"
#define NCHWC2_TRANS_FP32_LEAKY_RELU         \
  "vcge.f32 q5, q0, q15      @ q0 > 0 \n"    \
  "vcge.f32 q6, q1, q15      @ q0 > 0 \n"    \
  "vmul.f32 q9, q0, %q[scale]  \n"           \
  "vmul.f32 q10, q1, %q[scale] \n"           \
  "vbif q0, q9, q5           @ choose \n"    \
  "vbif q1, q10, q6          @ choose \n"
#define NCHWC2_TRANS_FP32_STORE                                        \
  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add "               \
  "pointer\n"                                                          \
  "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add "               \
  "pointer\n"                                                          \
  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"      \
  "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"      \
                                                                       \
  "subs %[cnt], %[cnt], #1         @ loop count - 1\n"                 \
...
@@ -776,6 +987,151 @@ inline bool write_to_output_c1_fp32(const float* din,
                                                                       \
  "bne 1b                          @ jump to main loop\n"
#endif
// clang-format on
inline void act_switch_c2_fp32(const float* din_ptr,
                               float* doutc0_ptr,
                               float* doutc1_ptr,
                               int cnt_loop,
                               const operators::ActivationParam* act_param) {
  if (act_param != nullptr && act_param->has_active) {
    float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
    float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
    switch (act_param->active_type) {
      case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
                         NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     :
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v20");
#else
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
                         NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     :
                     : "q0", "q1", "q2", "q3", "q15");
#endif
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
#ifdef __aarch64__
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
                         NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     : [six] "w"(six)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v20");
#else
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
                         NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     : [six] "w"(six)
                     : "q0", "q1", "q2", "q3", "q15");
#endif
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU
                         NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     : [scale] "w"(scale)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v20");
#else
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU
                         NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     : [scale] "w"(scale)
                     : "q0", "q1", "q2", "q3", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q15");
#endif
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param->active_type)
                   << " fuse not support";
    }
  } else {
#ifdef __aarch64__
    asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [doutc1r0] "+r"(doutc1_ptr),
                   [cnt] "+r"(cnt_loop),
                   [ptr_din] "+r"(din_ptr)
                 :
                 : "v0", "v1", "v2", "v3", "v4", "v5", "v20");
#else
    asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [doutc1r0] "+r"(doutc1_ptr),
                   [ptr_din] "+r"(din_ptr),
                   [cnt] "+r"(cnt_loop)
                 :
                 : "q0", "q1", "q2", "q3", "q15");
#endif
  }
}
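The c2 path works on channel-pair-interleaved data (c0r0, c1r0, c0r1, c1r1, ...), which is why the asm transposes before applying the activation and storing per channel. In intrinsics the same de-interleave is a single structured load (sketch only; the kernel keeps the vtrn/trn form):

  #include <arm_neon.h>

  // De-interleave eight packed floats into one vector per channel.
  static inline void deinterleave_c2(const float* din,
                                     float32x4_t* c0, float32x4_t* c1) {
    float32x4x2_t v = vld2q_f32(din);  // splits even/odd lanes on load
    *c0 = v.val[0];                    // c0r0..c0r3
    *c1 = v.val[1];                    // c1r0..c1r3
  }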
/* write result in outputs
 * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
 */
...
@@ -791,11 +1147,11 @@ inline bool write_to_output_c2_fp32(const float* din,
                                    int height,
                                    int width,
                                    bool flag_relu,
                                    float* trash_ptr) {
                                    float* trash_ptr,
                                    operators::ActivationParam* act_param) {
  if (cs > channel) {
    return true;
  }
  const int c2 = 2;
  const int w4 = 4;
...
@@ -828,55 +1184,56 @@ inline bool write_to_output_c2_fp32(const float* din,
    const float* din_hei_ptr = ptr_din + i * w_round * c2;
    if (cnt > 0) {
      int cnt_loop = cnt;
      if (flag_relu) {
#ifdef __aarch64__
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
                         NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_hei_ptr)
                     :
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v20");
#else
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
                         NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [ptr_din] "+r"(din_hei_ptr),
                       [cnt] "+r"(cnt_loop)
                     :
                     : "q0", "q1", "q2", "q3", "q15");
#endif
      } else {
#ifdef __aarch64__
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_hei_ptr)
                     :
                     : "v0", "v1", "v2", "v3", "v4", "v5");
#else
        asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [ptr_din] "+r"(din_hei_ptr),
                       [cnt] "+r"(cnt_loop)
                     :
                     : "q0", "q1", "q2", "q3", "q15");
#endif
      }
      act_switch_c2_fp32(
          din_hei_ptr, doutc0_ptr, doutc1_ptr, cnt_loop, act_param);
    }
    if (we > width) {
      int offset = i * w_round * c2 + c2 * w4 * cnt;
      din_hei_ptr = ptr_din + offset;
      doutc0_ptr += w4 * cnt;
      doutc1_ptr += w4 * cnt;
      int j = we - w4;
      if (flag_relu) {
        for (; j < width; ++j) {
          *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
          *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
          din_hei_ptr += 2;
      if (act_param != nullptr && act_param->has_active) {
        float six = act_param->Relu_clipped_coef;
        float scale = act_param->Leaky_relu_alpha;
        switch (act_param->active_type) {
          case lite_api::ActivationType::kRelu:
            for (; j < width; ++j) {
              *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
              *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
              din_hei_ptr += 2;
            }
            break;
          case lite_api::ActivationType::kRelu6:
            /* 0 <= din <= 6 */
            for (; j < width; ++j) {
              float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
              float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
              *(doutc0_ptr++) = LITEMIN(tmp1, six);
              *(doutc1_ptr++) = LITEMIN(tmp2, six);
              din_hei_ptr += 2;
            }
            break;
          case lite_api::ActivationType::kLeakyRelu:
            /*din = din >= 0 ? din : din * scale*/
            for (; j < width; ++j) {
              if (din_hei_ptr[0] >= 0) {
                *(doutc0_ptr++) = din_hei_ptr[0];
              } else {
                *(doutc0_ptr++) = din_hei_ptr[0] * scale;
              }
              if (din_hei_ptr[1] >= 0) {
                *(doutc1_ptr++) = din_hei_ptr[1];
              } else {
                *(doutc1_ptr++) = din_hei_ptr[1] * scale;
              }
              din_hei_ptr += 2;
            }
            break;
          default:
            LOG(FATAL) << "this act_type: "
                       << static_cast<int>(act_param->active_type)
                       << " fuse not support";
        }
      } else {
        for (; j < width; ++j) {
...
@@ -888,7 +1245,7 @@ inline bool write_to_output_c2_fp32(const float* din,
  }
  return true;
}
// clang-format off
#ifdef __aarch64__
#define NCHWC4_TRANS_FP32_COMPUTE                                    \
  "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */   \
...
@@ -912,6 +1269,26 @@ inline bool write_to_output_c2_fp32(const float* din,
  "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/                          \
  "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/
#define NCHWC4_TRANS_FP32_RELU6                       \
  "fmin v16.4s, v16.4s, %[six].4s \n" /* relu6 */     \
  "fmin v17.4s, v17.4s, %[six].4s \n" /* relu6 */     \
  "fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */     \
  "fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
#define NCHWC4_TRANS_FP32_LEAKY_RELU                  \
  "cmhs v8.4s, v16.4s, v20.4s \n"  /* vcgeq_u32 */    \
  "cmhs v9.4s, v17.4s, v20.4s \n"  /* vcgeq_u32 */    \
  "cmhs v10.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */    \
  "cmhs v11.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */    \
  "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */      \
  "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */      \
  "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */      \
  "fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */      \
  "bif v16.16b, v4.16b, v8.16b \n"  /* choose*/       \
  "bif v17.16b, v5.16b, v9.16b \n"  /* choose*/       \
  "bif v18.16b, v6.16b, v10.16b \n" /* choose*/       \
  "bif v19.16b, v7.16b, v11.16b \n" /* choose*/
#define NCHWC4_TRANS_FP32_STORE                       \
  "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/    \
  "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/    \
...
@@ -940,6 +1317,26 @@ inline bool write_to_output_c2_fp32(const float* din,
  "vmax.f32 q2, q2, q15 @ relu\n"                     \
  "vmax.f32 q3, q3, q15 @ relu\n"
#define NCHWC4_TRANS_FP32_RELU6              \
  "vmin.f32 q0, q0, %q[six] @ relu6 \n"      \
  "vmin.f32 q1, q1, %q[six] @ relu6 \n"      \
  "vmin.f32 q2, q2, %q[six] @ relu6 \n"      \
  "vmin.f32 q3, q3, %q[six] @ relu6 \n"
#define NCHWC4_TRANS_FP32_LEAKY_RELU         \
  "vcge.f32 q5, q0, q15     @ q0 > 0 \n"     \
  "vcge.f32 q6, q1, q15     @ q0 > 0 \n"     \
  "vcge.f32 q7, q2, q15     @ q0 > 0 \n"     \
  "vcge.f32 q8, q3, q15     @ q0 > 0 \n"     \
  "vmul.f32 q9, q0, %q[scale]  \n"           \
  "vmul.f32 q10, q1, %q[scale] \n"           \
  "vmul.f32 q11, q2, %q[scale] \n"           \
  "vmul.f32 q12, q3, %q[scale] \n"           \
  "vbif q0, q9, q5          @ choose \n"     \
  "vbif q1, q10, q6         @ choose \n"     \
  "vbif q2, q11, q7         @ choose \n"     \
  "vbif q3, q12, q8         @ choose \n"
#define NCHWC4_TRANS_FP32_STORE                                      \
  "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"    \
  "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"    \
...
@@ -953,68 +1350,19 @@ inline bool write_to_output_c2_fp32(const float* din,
                                                                     \
  "bne 1b                          @ jump to main loop\n"
#endif
/* write result in outputs
 * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
 */
inline bool write_to_output_c4_fp32(const float* din,
                                    float* dout,
                                    int cs,
                                    int ce,
                                    int hs,
                                    int he,
                                    int ws,
                                    int we,
                                    int channel,
                                    int height,
                                    int width,
                                    bool flag_relu,
                                    float* trash_ptr) {
  const int c4 = 4;
  const int w4 = 4;
  const int w_round = we - ws;
  const int ch_n = ce - cs;
  if (ch_n != 4) {
    LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is "
                  "more than zero";
    return false;
  }
  int size_c_out = width * height;
  float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
  float* doutc1r0 = doutc0r0 + size_c_out;
  float* doutc2r0 = doutc1r0 + size_c_out;
  float* doutc3r0 = doutc2r0 + size_c_out;
  const float* ptr_din = din;
  int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
  int valid_we = we > width ? width : we;
  int cnt = (valid_we - ws) / w4;
  int remain = valid_we - ws - cnt * w4;
  for (int i = 0; i < size_h; i++) {
    int size_w = i * width;
    float* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
    float* doutc1_ptr = doutc1r0 + size_w;
    float* doutc2_ptr = doutc2r0 + size_w;
    float* doutc3_ptr = doutc3r0 + size_w;
    if (ce > channel) {
      switch (ce - channel) {
        case 3:
          doutc1_ptr = trash_ptr;
        case 2:
          doutc2_ptr = trash_ptr;
        case 1:
          doutc3_ptr = trash_ptr;
        default:
          break;
      }
    }
    const float* din_hei_ptr = ptr_din + i * w_round * ch_n;
    if (cnt > 0) {
      int cnt_loop = cnt;
      if (flag_relu) {
// clang-format on
inline void act_switch_c4_fp32(const float* din_ptr,
                               float* doutc0_ptr,
                               float* doutc1_ptr,
                               float* doutc2_ptr,
                               float* doutc3_ptr,
                               int cnt_loop,
                               const operators::ActivationParam* act_param) {
  if (act_param != nullptr && act_param->has_active) {
    float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
    float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
    switch (act_param->active_type) {
      case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
                         NCHWC4_TRANS_FP32_STORE
...
@@ -1023,7 +1371,7 @@ inline bool write_to_output_c4_fp32(const float* din,
                     : [doutc2r0] "+r"(doutc2_ptr),
                       [doutc3r0] "+r"(doutc3_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_hei_ptr)
                       [ptr_din] "+r"(din_ptr)
                     :
                     : "v0", "v1",
...
@@ -1052,57 +1400,290 @@ inline bool write_to_output_c4_fp32(const float* din,
                       [doutc1r0] "+r"(doutc1_ptr),
                       [doutc2r0] "+r"(doutc2_ptr),
                       [doutc3r0] "+r"(doutc3_ptr),
                       [ptr_din] "+r"(din_hei_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     :
                     : "q0", "q1", "q2", "q3", "q15");
#endif
      } else {
        break;
      case lite_api::ActivationType::kRelu6:
        /* 0 <= din <= 6 */
#ifdef __aarch64__
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
                         NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [doutc2r0] "+r"(doutc2_ptr),
                       [doutc3r0] "+r"(doutc3_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_hei_ptr)
                     :
                       [ptr_din] "+r"(din_ptr)
                     : [six] "w"(six)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v16", "v17",
                       "v18", "v19");
                       "v19", "v20");
#else
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
                         NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [doutc2r0] "+r"(doutc2_ptr),
                       [doutc3r0] "+r"(doutc3_ptr),
                       [ptr_din] "+r"(din_hei_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     :
                     : "q0", "q1", "q2", "q3");
                     : [six] "w"(six)
                     : "q0", "q1", "q2", "q3", "q15");
#endif
        break;
      case lite_api::ActivationType::kLeakyRelu:
        /*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU
                         NCHWC4_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [doutc2r0] "+r"(doutc2_ptr),
                       [doutc3r0] "+r"(doutc3_ptr),
                       [cnt] "+r"(cnt_loop),
                       [ptr_din] "+r"(din_ptr)
                     : [scale] "w"(scale)
                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
                       "v9", "v10", "v11", "v12", "v13", "v14", "v16", "v17",
                       "v18", "v19", "v20");
#else
        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU
                         NCHWC4_TRANS_FP32_STORE
                     : [doutc0r0] "+r"(doutc0_ptr),
                       [doutc1r0] "+r"(doutc1_ptr),
                       [doutc2r0] "+r"(doutc2_ptr),
                       [doutc3r0] "+r"(doutc3_ptr),
                       [ptr_din] "+r"(din_ptr),
                       [cnt] "+r"(cnt_loop)
                     : [scale] "w"(scale)
                     : "q0", "q1", "q2", "q3", "q5", "q6", "q7", "q8", "q9",
                       "q10", "q11", "q12", "q15");
#endif
        break;
      default:
        LOG(FATAL) << "this act_type: "
                   << static_cast<int>(act_param->active_type)
                   << " fuse not support";
    }
  } else {
#ifdef __aarch64__
    asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [doutc1r0] "+r"(doutc1_ptr),
                   [doutc2r0] "+r"(doutc2_ptr),
                   [doutc3r0] "+r"(doutc3_ptr),
                   [cnt] "+r"(cnt_loop),
                   [ptr_din] "+r"(din_ptr)
                 :
                 : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16",
                   "v17", "v18", "v19");
#else
    asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [doutc1r0] "+r"(doutc1_ptr),
                   [doutc2r0] "+r"(doutc2_ptr),
                   [doutc3r0] "+r"(doutc3_ptr),
                   [ptr_din] "+r"(din_ptr),
                   [cnt] "+r"(cnt_loop)
                 :
                 : "q0", "q1", "q2", "q3", "q15");
#endif
  }
}
/* write result in outputs
 * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
 */
inline bool write_to_output_c4_fp32(const float* din,
                                    float* dout,
                                    int cs,
                                    int ce,
                                    int hs,
                                    int he,
                                    int ws,
                                    int we,
                                    int channel,
                                    int height,
                                    int width,
                                    bool flag_relu,
                                    float* trash_ptr,
                                    operators::ActivationParam* act_param) {
  const int c4 = 4;
  const int w4 = 4;
  const int w_round = we - ws;
  const int ch_n = ce - cs;
  if (ch_n != 4) {
    LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is "
                  "more than zero";
    return false;
  }
  int size_c_out = width * height;
  float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
  float* doutc1r0 = doutc0r0 + size_c_out;
  float* doutc2r0 = doutc1r0 + size_c_out;
  float* doutc3r0 = doutc2r0 + size_c_out;
  const float* ptr_din = din;
  int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
  int valid_we = we > width ? width : we;
  int cnt = (valid_we - ws) / w4;
  int remain = valid_we - ws - cnt * w4;
  for (int i = 0; i < size_h; i++) {
    int size_w = i * width;
    float* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
    float* doutc1_ptr = doutc1r0 + size_w;
    float* doutc2_ptr = doutc2r0 + size_w;
    float* doutc3_ptr = doutc3r0 + size_w;
    if (ce > channel) {
      switch (ce - channel) {
        case 3:
          doutc1_ptr = trash_ptr;
        case 2:
          doutc2_ptr = trash_ptr;
        case 1:
          doutc3_ptr = trash_ptr;
        default:
          break;
      }
    }
    const float* din_hei_ptr = ptr_din + i * w_round * ch_n;
    if (cnt > 0) {
      int cnt_loop = cnt;
      act_switch_c4_fp32(din_hei_ptr,
                         doutc0_ptr,
                         doutc1_ptr,
                         doutc2_ptr,
                         doutc3_ptr,
                         cnt_loop,
                         act_param);
    }
    if (remain > 0) {
      int offset = i * w_round * c4 + c4 * w4 * cnt;
      din_hei_ptr = ptr_din + offset;
      doutc0_ptr += w4 * cnt;
      doutc1_ptr += w4 * cnt;
      doutc2_ptr += w4 * cnt;
      doutc3_ptr += w4 * cnt;
      int j = 0;
      if (flag_relu) {
        for (; j < remain; ++j) {
          *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
          *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
          *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
          *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
          din_hei_ptr += w4;
      if (act_param != nullptr && act_param->has_active) {
        float six = act_param->Relu_clipped_coef;
        float scale = act_param->Leaky_relu_alpha;
        switch (act_param->active_type) {
          case lite_api::ActivationType::kRelu:
            for (; j < remain; ++j) {
              *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
              *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
              *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
              *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
              din_hei_ptr += 4;
            }
            break;
          case lite_api::ActivationType::kRelu6:
            /* 0 <= din <= 6 */
            for (; j < remain; ++j) {
              float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
              float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
              float tmp3 = LITEMAX(din_hei_ptr[2], 0.f);
              float tmp4 = LITEMAX(din_hei_ptr[3], 0.f);
              *(doutc0_ptr++) = LITEMIN(tmp1, six);
              *(doutc1_ptr++) = LITEMIN(tmp2, six);
              *(doutc2_ptr++) = LITEMIN(tmp3, six);
              *(doutc3_ptr++) = LITEMIN(tmp4, six);
              din_hei_ptr += 4;
            }
            break;
          case lite_api::ActivationType::kLeakyRelu:
            /*din = din >= 0 ? din : din * scale*/
            for (; j < remain; ++j) {
              if (din_hei_ptr[0] >= 0) {
                *(doutc0_ptr++) = din_hei_ptr[0];
              } else {
                *(doutc0_ptr++) = din_hei_ptr[0] * scale;
              }
              if (din_hei_ptr[1] >= 0) {
                *(doutc1_ptr++) = din_hei_ptr[1];
              } else {
                *(doutc1_ptr++) = din_hei_ptr[1] * scale;
              }
              if (din_hei_ptr[2] >= 0) {
                *(doutc2_ptr++) = din_hei_ptr[2];
              } else {
                *(doutc2_ptr++) = din_hei_ptr[2] * scale;
              }
              if (din_hei_ptr[3] >= 0) {
                *(doutc3_ptr++) = din_hei_ptr[3];
              } else {
                *(doutc3_ptr++) = din_hei_ptr[3] * scale;
              }
              din_hei_ptr += 4;
            }
            break;
          default:
            LOG(FATAL) << "this act_type: "
                       << static_cast<int>(act_param->active_type)
                       << " fuse not support";
        }
      } else {
        for (; j < remain; ++j) {
...
@@ -1110,14 +1691,14 @@ inline bool write_to_output_c4_fp32(const float* din,
          *(doutc1_ptr++) = din_hei_ptr[1];
          *(doutc2_ptr++) = din_hei_ptr[2];
          *(doutc3_ptr++) = din_hei_ptr[3];
          din_hei_ptr += w4;
          din_hei_ptr += 4;
        }
      }
    }
  }
  return true;
}
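The `switch (ce - channel)` blocks above fall through on purpose: when the rounded-up channel tile overruns the real channel count, the surplus row pointers are parked on `trash_ptr` so the vector stores never need a bounds check. A scalar sketch of the same redirection:

  // Redirect the top k of 4 output rows to scratch when ce > channel.
  void redirect_overflow(float* outs[4], int ce, int channel,
                         float* trash_ptr) {
    int overflow = ce - channel;
    for (int k = 0; k < overflow && k < 4; ++k) {
      outs[3 - k] = trash_ptr;  // highest channels write to scratch
    }
  }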
// clang-format off
#ifdef __aarch64__
#define NCHWC8_TRANS_FP32_COMPUTE \
"ldp q0, q1, [%[ptr_din]], #32 \n"
/* load r00, r01 to q0, q1 */
\
...
...
@@ -1161,6 +1742,48 @@ inline bool write_to_output_c4_fp32(const float* din,
"fmax v12.4s, v12.4s, v20.4s \n"
/*relu*/
\
"fmax v13.4s, v13.4s, v20.4s \n"
/*relu*/
#define NCHWC8_TRANS_FP32_RELU6 \
"fmin v16.4s, v16.4s, %[six].4s \n"
/*relu6*/
\
"fmin v17.4s, v17.4s, %[six].4s \n"
/*relu6*/
\
"fmin v18.4s, v18.4s, %[six].4s \n"
/*relu6*/
\
"fmin v19.4s, v19.4s, %[six].4s \n"
/*relu6*/
\
\
"fmin v8.4s, v8.4s, %[six].4s \n"
/*relu6*/
\
"fmin v9.4s, v9.4s, %[six].4s \n"
/*relu6*/
\
"fmin v12.4s, v12.4s, %[six].4s \n"
/*relu6*/
\
"fmin v13.4s, v13.4s, %[six].4s \n"
/*relu6*/
#define NCHWC8_TRANS_FP32_LEAKY_RELU \
"cmhs v10.4s, v16.4s, v20.4s \n"
/* vcgeq_u32 */
\
"cmhs v11.4s, v17.4s, v20.4s \n"
/* vcgeq_u32 */
\
"cmhs v14.4s, v18.4s, v20.4s \n"
/* vcgeq_u32 */
\
"cmhs v15.4s, v19.4s, v20.4s \n"
/* vcgeq_u32 */
\
\
"cmhs v21.4s, v8.4s, v20.4s \n"
/* vcgeq_u32 */
\
"cmhs v22.4s, v9.4s, v20.4s \n"
/* vcgeq_u32 */
\
"cmhs v23.4s, v12.4s, v20.4s \n"
/* vcgeq_u32 */
\
"cmhs v24.4s, v13.4s, v20.4s \n"
/* vcgeq_u32 */
\
\
"fmul v25.4s, v16.4s, %[scale].4s \n"
/* mul */
\
"fmul v26.4s, v17.4s, %[scale].4s \n"
/* mul */
\
"fmul v27.4s, v18.4s, %[scale].4s \n"
/* mul */
\
"fmul v28.4s, v19.4s, %[scale].4s \n"
/* mul */
\
\
"fmul v29.4s, v8.4s, %[scale].4s \n"
/* mul */
\
"fmul v30.4s, v9.4s, %[scale].4s \n"
/* mul */
\
"fmul v31.4s, v12.4s, %[scale].4s \n"
/* mul */
\
\
"bif v16.16b, v25.16b, v10.16b \n"
/* choose*/
\
"bif v17.16b, v26.16b, v11.16b \n"
/* choose*/
\
"bif v18.16b, v27.16b, v14.16b \n"
/* choose*/
\
"bif v19.16b, v28.16b, v15.16b \n"
/* choose*/
\
"fmul v25.4s, v13.4s, %[scale].4s \n"
/* mul */
\
\
"bif v8.16b, v29.16b, v21.16b \n"
/* choose*/
\
"bif v9.16b, v30.16b, v22.16b \n"
/* choose*/
\
"bif v12.16b, v31.16b, v23.16b \n"
/* choose*/
\
"bif v13.16b, v25.16b, v24.16b \n"
/* choose*/
#define NCHWC8_TRANS_FP32_STORE \
"str q16, [%[doutc0r0]], #16 \n"
/* store c0r0*/
\
"str q17, [%[doutc2r0]], #16 \n"
/* store c2r0*/
\
...
...
@@ -1174,6 +1797,7 @@ inline bool write_to_output_c4_fp32(const float* din,
"str q13, [%[doutc7r0]], #16 \n"
/* store c3r0*/
\
\
"bne 1b \n"
/* jump to main loop*/
#else
#define NCHWC8_TRANS_FP32_COMPUTE \
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \
...
...
@@ -1203,6 +1827,48 @@ inline bool write_to_output_c4_fp32(const float* din,
   "vmax.f32 q6, q6, q15        @ relu\n" \
   "vmax.f32 q7, q7, q15        @ relu\n"
+#define NCHWC8_TRANS_FP32_RELU6             \
+  "vmin.f32 q0, q0, %q[six]    @ relu6\n"   \
+  "vmin.f32 q1, q1, %q[six]    @ relu6\n"   \
+  "vmin.f32 q2, q2, %q[six]    @ relu6\n"   \
+  "vmin.f32 q3, q3, %q[six]    @ relu6\n"   \
+                                            \
+  "vmin.f32 q4, q4, %q[six]    @ relu6\n"   \
+  "vmin.f32 q5, q5, %q[six]    @ relu6\n"   \
+  "vmin.f32 q6, q6, %q[six]    @ relu6\n"   \
+  "vmin.f32 q7, q7, %q[six]    @ relu6\n"
+#define NCHWC8_TRANS_FP32_LEAKY_RELU        \
+  "vcge.f32 q9, q0, q15        @ q0 > 0 \n" \
+  "vcge.f32 q10, q1, q15       @ q0 > 0 \n" \
+  "vcge.f32 q11, q2, q15       @ q0 > 0 \n" \
+  "vcge.f32 q12, q3, q15       @ q0 > 0 \n" \
+  "vmul.f32 q13, q0, %q[scale] \n"          \
+  "vmul.f32 q14, q1, %q[scale] \n"          \
+  "vmul.f32 q15, q2, %q[scale] \n"          \
+                                            \
+  "vbif q0, q13, q9            @ choose \n" \
+  "vmul.f32 q9, q3, %q[scale]  \n"          \
+                                            \
+  "vbif q1, q14, q10           @ choose \n" \
+  "vbif q2, q15, q11           @ choose \n" \
+  "vbif q3, q9, q12            @ choose \n" \
+                                            \
+  "vcge.f32 q9, q4, q15        @ q0 > 0 \n" \
+  "vcge.f32 q10, q5, q15       @ q0 > 0 \n" \
+  "vcge.f32 q11, q6, q15       @ q0 > 0 \n" \
+  "vcge.f32 q12, q7, q15       @ q0 > 0 \n" \
+  "vmul.f32 q13, q4, %q[scale] \n"          \
+  "vmul.f32 q14, q5, %q[scale] \n"          \
+  "vmul.f32 q15, q6, %q[scale] \n"          \
+                                            \
+  "vbif q4, q13, q9            @ choose \n" \
+  "vmul.f32 q9, q7, %q[scale]  \n"          \
+                                            \
+  "vbif q5, q14, q10           @ choose \n" \
+  "vbif q6, q15, q11           @ choose \n" \
+  "vbif q7, q9, q12            @ choose \n"
 #define NCHWC8_TRANS_FP32_STORE                            \
   "subs %[cnt], %[cnt], #1     @ loop count - 1\n"         \
   "vst1.32 {d0-d1}, [%[doutc0r0]]!  @ store result, add "  \
...
...
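Note that both RELU6 macro variants apply only the upper clamp (`fmin` / `vmin.f32` against the broadcast six), so realizing the full relu6 = min(max(x, 0), six) relies on combining them with the relu lower clamp. As a per-lane intrinsics illustration (not part of the patch):

#include <arm_neon.h>

// ReLU6 on four lanes: clamp to [0, six], i.e. the relu max followed
// by the relu6 min of the macros above.
static inline float32x4_t relu6_f32x4(float32x4_t x, float32x4_t six) {
  return vminq_f32(vmaxq_f32(x, vdupq_n_f32(0.f)), six);
}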
@@ -1232,84 +1898,23 @@ inline bool write_to_output_c4_fp32(const float* din,
   "bne 1b                      @ jump to main loop\n"
 #endif
-/*wirte result in outputs
-* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
-*/
-inline bool write_to_output_c8_fp32(const float* din,
-                                    float* dout,
-                                    int ch_n,
-                                    int hei_n,
-                                    int cs,
-                                    int ce,
-                                    int hs,
-                                    int he,
-                                    int ws,
-                                    int we,
-                                    int channel,
-                                    int height,
-                                    int width,
-                                    bool flag_relu,
-                                    float* trash_ptr) {
-  if (ch_n != 8 || hei_n <= 0) {
-    LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
-    return false;
-  }
-  int size_c_out = width * height;
-  float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
-  float* doutc1r0 = doutc0r0 + size_c_out;
-  float* doutc2r0 = doutc1r0 + size_c_out;
-  float* doutc3r0 = doutc2r0 + size_c_out;
-  float* doutc4r0 = doutc3r0 + size_c_out;
-  float* doutc5r0 = doutc4r0 + size_c_out;
-  float* doutc6r0 = doutc5r0 + size_c_out;
-  float* doutc7r0 = doutc6r0 + size_c_out;
-  const float* ptr_din = din;
-  int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
-  int valid_w = we - ws;
-  int cnt = valid_w / 4;
-  if (we > width) {
-    cnt--;
-  }
-  if (flag_relu) {
-    for (int i = 0; i < size_h; i++) {
-      int size_w = i * width;
-      float* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
-      float* doutc1_ptr = doutc1r0 + size_w;
-      float* doutc2_ptr = doutc2r0 + size_w;
-      float* doutc3_ptr = doutc3r0 + size_w;
-      float* doutc4_ptr = doutc4r0 + size_w;
-      float* doutc5_ptr = doutc5r0 + size_w;
-      float* doutc6_ptr = doutc6r0 + size_w;
-      float* doutc7_ptr = doutc7r0 + size_w;
-      if (ce > channel) {
-        switch (ce - channel) {
-          case 7:
-            doutc1_ptr = trash_ptr;
-          case 6:
-            doutc2_ptr = trash_ptr;
-          case 5:
-            doutc3_ptr = trash_ptr;
-          case 4:
-            doutc4_ptr = trash_ptr;
-          case 3:
-            doutc5_ptr = trash_ptr;
-          case 2:
-            doutc6_ptr = trash_ptr;
-          case 1:
-            doutc7_ptr = trash_ptr;
-          default:
-            break;
-        }
-      }
-      ptr_din = din + i * valid_w * ch_n;
-      const float* din_hei_ptr = ptr_din;
-      if (cnt > 0) {
-        int cnt_loop = cnt;
+// clang-format on
+inline void act_switch_c8_fp32(const float* din_ptr,
+                               float* doutc0_ptr,
+                               float* doutc1_ptr,
+                               float* doutc2_ptr,
+                               float* doutc3_ptr,
+                               float* doutc4_ptr,
+                               float* doutc5_ptr,
+                               float* doutc6_ptr,
+                               float* doutc7_ptr,
+                               int cnt_loop,
+                               const operators::ActivationParam* act_param) {
+  if (act_param != nullptr && act_param->has_active) {
+    float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
+    float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
+    switch (act_param->active_type) {
+      case lite_api::ActivationType::kRelu:
 #ifdef __aarch64__
         asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
                          NCHWC8_TRANS_FP32_STORE
...
...
@@ -1322,9 +1927,10 @@ inline bool write_to_output_c8_fp32(const float* din,
                        [doutc6r0] "+r"(doutc6_ptr),
                        [doutc7r0] "+r"(doutc7_ptr),
                        [cnt] "+r"(cnt_loop),
-                       [ptr_din] "+r"(din_hei_ptr)
+                       [ptr_din] "+r"(din_ptr)
                      :
-                     : "v1",
+                     : "v0",
+                       "v1",
                        "v2",
                        "v3",
                        "v4",
...
...
@@ -1338,7 +1944,6 @@ inline bool write_to_output_c8_fp32(const float* din,
                        "v12",
                        "v13",
                        "v14",
-                       "v15",
                        "v16",
                        "v17",
                        "v18",
...
...
@@ -1355,66 +1960,17 @@ inline bool write_to_output_c8_fp32(const float* din,
                        [doutc5r0] "+r"(doutc5_ptr),
                        [doutc6r0] "+r"(doutc6_ptr),
                        [doutc7r0] "+r"(doutc7_ptr),
-                       [ptr_din] "+r"(din_hei_ptr),
+                       [ptr_din] "+r"(din_ptr),
                        [cnt] "+r"(cnt_loop)
                      :
-                     : "q0", "q1", "q2", "q3", "q4", "q15");
+                     : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+                       "q15");
 #endif
-      }
-      if (we > width) {
-        int offset = 32 * (valid_w / 4 - 1);
-        din_hei_ptr = ptr_din + offset;
-        int i = we - 4;
-        for (; i < width; ++i) {
-          *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
-          *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
-          *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
-          *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
-          *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f);
-          *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f);
-          *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f);
-          *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f);
-          din_hei_ptr += 8;
-        }
-      }
-    }
-  } else {
-    for (int i = 0; i < size_h; i++) {
-      int size_w = i * width;
-      float* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
-      float* doutc1_ptr = doutc1r0 + size_w;
-      float* doutc2_ptr = doutc2r0 + size_w;
-      float* doutc3_ptr = doutc3r0 + size_w;
-      float* doutc4_ptr = doutc4r0 + size_w;
-      float* doutc5_ptr = doutc5r0 + size_w;
-      float* doutc6_ptr = doutc6r0 + size_w;
-      float* doutc7_ptr = doutc7r0 + size_w;
-      if (ce > channel) {
-        switch (ce - channel) {
-          case 7:
-            doutc1_ptr = trash_ptr;
-          case 6:
-            doutc2_ptr = trash_ptr;
-          case 5:
-            doutc3_ptr = trash_ptr;
-          case 4:
-            doutc4_ptr = trash_ptr;
-          case 3:
-            doutc5_ptr = trash_ptr;
-          case 2:
-            doutc6_ptr = trash_ptr;
-          case 1:
-            doutc7_ptr = trash_ptr;
-          default:
-            break;
-        }
-      }
-      ptr_din = din + i * valid_w * ch_n;
-      const float* din_hei_ptr = ptr_din;
-      if (cnt > 0) {
-        int cnt_loop = cnt;
+        break;
+      case lite_api::ActivationType::kRelu6:
+        /* 0 <= din <= 6 */
 #ifdef __aarch64__
-        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
+        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU6
+                         NCHWC8_TRANS_FP32_STORE
                      : [doutc0r0] "+r"(doutc0_ptr),
                        [doutc1r0] "+r"(doutc1_ptr),
                        [doutc2r0] "+r"(doutc2_ptr),
...
...
@@ -1424,8 +1980,8 @@ inline bool write_to_output_c8_fp32(const float* din,
                        [doutc6r0] "+r"(doutc6_ptr),
                        [doutc7r0] "+r"(doutc7_ptr),
                        [cnt] "+r"(cnt_loop),
-                       [ptr_din] "+r"(din_hei_ptr)
-                     :
+                       [ptr_din] "+r"(din_ptr)
+                     : [six] "w"(six)
                      : "v0",
                        "v1",
                        "v2",
...
...
@@ -1441,14 +1997,29 @@ inline bool write_to_output_c8_fp32(const float* din,
                        "v12",
                        "v13",
                        "v14",
                        "v15",
                        "v16",
                        "v17",
                        "v18",
                        "v19",
                        "v20");
 #else
-        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
+        asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
+                         NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE
                      : [doutc0r0] "+r"(doutc0_ptr),
                        [doutc1r0] "+r"(doutc1_ptr),
                        [doutc2r0] "+r"(doutc2_ptr),
                        [doutc3r0] "+r"(doutc3_ptr),
                        [ptr_din] "+r"(din_ptr),
                        [cnt] "+r"(cnt_loop)
+                     : [six] "w"(six)
                      : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
                        "q15");
 #endif
+        break;
+      case lite_api::ActivationType::kLeakyRelu:
+        /*din = din >= 0 ? din : din * scale*/
 #ifdef __aarch64__
+        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
+                         NCHWC8_TRANS_FP32_STORE
                      : [doutc0r0] "+r"(doutc0_ptr),
                        [doutc1r0] "+r"(doutc1_ptr),
                        [doutc2r0] "+r"(doutc2_ptr),
...
...
@@ -1457,16 +2028,323 @@ inline bool write_to_output_c8_fp32(const float* din,
                        [doutc5r0] "+r"(doutc5_ptr),
                        [doutc6r0] "+r"(doutc6_ptr),
                        [doutc7r0] "+r"(doutc7_ptr),
-                       [ptr_din] "+r"(din_hei_ptr),
-                       [cnt] "+r"(cnt_loop)
+                       [cnt] "+r"(cnt_loop),
+                       [ptr_din] "+r"(din_ptr)
+                     : [scale] "w"(scale)
+                     : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8",
+                       "v9", "v10", "v11", "v12", "v13", "v14", "v16", "v17",
+                       "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+                       "v26", "v27", "v28", "v29", "v30", "v31");
 #else
+        asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
+                         NCHWC8_TRANS_FP32_STORE
                      : [doutc0r0] "+r"(doutc0_ptr),
                        [doutc1r0] "+r"(doutc1_ptr),
                        [doutc2r0] "+r"(doutc2_ptr),
                        [doutc3r0] "+r"(doutc3_ptr),
                        [doutc4r0] "+r"(doutc4_ptr),
                        [doutc5r0] "+r"(doutc5_ptr),
                        [doutc6r0] "+r"(doutc6_ptr),
                        [doutc7r0] "+r"(doutc7_ptr),
                        [ptr_din] "+r"(din_ptr),
                        [cnt] "+r"(cnt_loop)
-                     :
-                     : "q0", "q1", "q2", "q3", "q4");
+                     : [scale] "w"(scale)
+                     : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q9",
+                       "q10", "q11", "q12", "q13", "q14", "q15");
 #endif
+        break;
+      default:
+        LOG(FATAL) << "this act_type: "
+                   << static_cast<int>(act_param->active_type)
+                   << " fuse not support";
+    }
+  } else {
#ifdef __aarch64__
    asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [doutc1r0] "+r"(doutc1_ptr),
                   [doutc2r0] "+r"(doutc2_ptr),
                   [doutc3r0] "+r"(doutc3_ptr),
                   [doutc4r0] "+r"(doutc4_ptr),
                   [doutc5r0] "+r"(doutc5_ptr),
                   [doutc6r0] "+r"(doutc6_ptr),
                   [doutc7r0] "+r"(doutc7_ptr),
                   [cnt] "+r"(cnt_loop),
                   [ptr_din] "+r"(din_ptr)
                 :
                 : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
                   "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
                   "v18", "v19", "v20");
#else
    asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
                 : [doutc0r0] "+r"(doutc0_ptr),
                   [doutc1r0] "+r"(doutc1_ptr),
                   [doutc2r0] "+r"(doutc2_ptr),
                   [doutc3r0] "+r"(doutc3_ptr),
                   [doutc4r0] "+r"(doutc4_ptr),
                   [doutc5r0] "+r"(doutc5_ptr),
                   [doutc6r0] "+r"(doutc6_ptr),
                   [doutc7r0] "+r"(doutc7_ptr),
                   [ptr_din] "+r"(din_ptr),
                   [cnt] "+r"(cnt_loop)
                 :
                 : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
#endif
  }
}
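For reference, the three fused activations that `act_switch_c8_fp32` dispatches correspond to these scalar definitions, with `six` taken from `Relu_clipped_coef` and `alpha` from `Leaky_relu_alpha` (an illustrative sketch, not part of the patch):

#include <algorithm>

// Scalar equivalents of the fused activation cases above.
static inline float relu(float x) { return std::max(x, 0.f); }
static inline float relu6(float x, float six) {
  return std::min(std::max(x, 0.f), six);
}
static inline float leaky_relu(float x, float alpha) {
  return x >= 0.f ? x : x * alpha;
}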
/*write result in outputs
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
inline bool write_to_output_c8_fp32(const float* din,
                                    float* dout,
                                    int ch_n,
                                    int hei_n,
                                    int cs,
                                    int ce,
                                    int hs,
                                    int he,
                                    int ws,
                                    int we,
                                    int channel,
                                    int height,
                                    int width,
                                    bool flag_relu,
                                    float* trash_ptr,
                                    operators::ActivationParam* act_param) {
  if (ch_n != 8 || hei_n <= 0) {
    LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
    return false;
  }
  int size_c_out = width * height;
  float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
  float* doutc1r0 = doutc0r0 + size_c_out;
  float* doutc2r0 = doutc1r0 + size_c_out;
  float* doutc3r0 = doutc2r0 + size_c_out;
  float* doutc4r0 = doutc3r0 + size_c_out;
  float* doutc5r0 = doutc4r0 + size_c_out;
  float* doutc6r0 = doutc5r0 + size_c_out;
  float* doutc7r0 = doutc6r0 + size_c_out;
  const float* ptr_din = din;
  int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
  int valid_w = we - ws;
  int w4 = 4;
  int cnt = valid_w / 4;
  if (we > width) {
    cnt--;
  }
  for (int i = 0; i < size_h; i++) {
    int size_w = i * width;
    float* doutc0_ptr = doutc0r0 + size_w;  // doutc0r0 + width;
    float* doutc1_ptr = doutc1r0 + size_w;
    float* doutc2_ptr = doutc2r0 + size_w;
    float* doutc3_ptr = doutc3r0 + size_w;
    float* doutc4_ptr = doutc4r0 + size_w;
    float* doutc5_ptr = doutc5r0 + size_w;
    float* doutc6_ptr = doutc6r0 + size_w;
    float* doutc7_ptr = doutc7r0 + size_w;
    if (ce > channel) {
      switch (ce - channel) {
        case 7:
          doutc1_ptr = trash_ptr;
        case 6:
          doutc2_ptr = trash_ptr;
        case 5:
          doutc3_ptr = trash_ptr;
        case 4:
          doutc4_ptr = trash_ptr;
        case 3:
          doutc5_ptr = trash_ptr;
        case 2:
          doutc6_ptr = trash_ptr;
        case 1:
          doutc7_ptr = trash_ptr;
        default:
          break;
      }
    }
    ptr_din = din + i * valid_w * ch_n;
    const float* din_hei_ptr = ptr_din;
    if (cnt > 0) {
      int cnt_loop = cnt;
      act_switch_c8_fp32(din_hei_ptr,
                         doutc0_ptr,
                         doutc1_ptr,
                         doutc2_ptr,
                         doutc3_ptr,
                         doutc4_ptr,
                         doutc5_ptr,
                         doutc6_ptr,
                         doutc7_ptr,
                         cnt_loop,
                         act_param);
    }
    if (we > width) {
      int offset = 32 * (valid_w / 4 - 1);
      din_hei_ptr = ptr_din + offset;
      doutc0_ptr += w4 * cnt;
      doutc1_ptr += w4 * cnt;
      doutc2_ptr += w4 * cnt;
      doutc3_ptr += w4 * cnt;
      doutc4_ptr += w4 * cnt;
      doutc5_ptr += w4 * cnt;
      doutc6_ptr += w4 * cnt;
      doutc7_ptr += w4 * cnt;
      int i = we - 4;
      if (act_param != nullptr && act_param->has_active) {
        float six = act_param->Relu_clipped_coef;
        float scale = act_param->Leaky_relu_alpha;
        switch (act_param->active_type) {
          case lite_api::ActivationType::kRelu:
            for (; i < width; ++i) {
              *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
              *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
              *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
              *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
              *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f);
              *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f);
              *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f);
              *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f);
              din_hei_ptr += 8;
            }
            break;
          case lite_api::ActivationType::kRelu6:
            /* 0 <= din <= 6 */
            for (; i < width; ++i) {
              float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
              float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
              float tmp3 = LITEMAX(din_hei_ptr[2], 0.f);
              float tmp4 = LITEMAX(din_hei_ptr[3], 0.f);
              float tmp5 = LITEMAX(din_hei_ptr[4], 0.f);
              float tmp6 = LITEMAX(din_hei_ptr[5], 0.f);
              float tmp7 = LITEMAX(din_hei_ptr[6], 0.f);
              float tmp8 = LITEMAX(din_hei_ptr[7], 0.f);
              *(doutc0_ptr++) = LITEMIN(tmp1, six);
              *(doutc1_ptr++) = LITEMIN(tmp2, six);
              *(doutc2_ptr++) = LITEMIN(tmp3, six);
              *(doutc3_ptr++) = LITEMIN(tmp4, six);
              *(doutc4_ptr++) = LITEMIN(tmp5, six);
              *(doutc5_ptr++) = LITEMIN(tmp6, six);
              *(doutc6_ptr++) = LITEMIN(tmp7, six);
              *(doutc7_ptr++) = LITEMIN(tmp8, six);
              din_hei_ptr += 8;
            }
            break;
          case lite_api::ActivationType::kLeakyRelu:
            /*din = din >= 0 ? din : din * scale*/
            for (; i < width; ++i) {
              if (din_hei_ptr[0] >= 0) {
                *(doutc0_ptr++) = din_hei_ptr[0];
              } else {
                *(doutc0_ptr++) = din_hei_ptr[0] * scale;
              }
              if (din_hei_ptr[1] >= 0) {
                *(doutc1_ptr++) = din_hei_ptr[1];
              } else {
                *(doutc1_ptr++) = din_hei_ptr[1] * scale;
              }
              if (din_hei_ptr[2] >= 0) {
                *(doutc2_ptr++) = din_hei_ptr[2];
              } else {
                *(doutc2_ptr++) = din_hei_ptr[2] * scale;
              }
              if (din_hei_ptr[3] >= 0) {
                *(doutc3_ptr++) = din_hei_ptr[3];
              } else {
                *(doutc3_ptr++) = din_hei_ptr[3] * scale;
              }
              if (din_hei_ptr[4] >= 0) {
                *(doutc4_ptr++) = din_hei_ptr[4];
              } else {
                *(doutc4_ptr++) = din_hei_ptr[4] * scale;
              }
              if (din_hei_ptr[5] >= 0) {
                *(doutc5_ptr++) = din_hei_ptr[5];
              } else {
                *(doutc5_ptr++) = din_hei_ptr[5] * scale;
              }
              if (din_hei_ptr[6] >= 0) {
                *(doutc6_ptr++) = din_hei_ptr[6];
              } else {
                *(doutc6_ptr++) = din_hei_ptr[6] * scale;
              }
              if (din_hei_ptr[7] >= 0) {
                *(doutc7_ptr++) = din_hei_ptr[7];
              } else {
                *(doutc7_ptr++) = din_hei_ptr[7] * scale;
              }
              din_hei_ptr += 8;
            }
            break;
          default:
            LOG(FATAL) << "this act_type: "
                       << static_cast<int>(act_param->active_type)
                       << " fuse not support";
        }
      } else {
        for (; i < width; ++i) {
          *(doutc0_ptr++) = din_hei_ptr[0];
          *(doutc1_ptr++) = din_hei_ptr[1];
...
...
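The function header comment gives the contract of the writer above: the input block is laid out as [n, c/8, h, w*8] and must be scattered back into dense NCHW, with the fused activation applied on the way out. A scalar reference of that transform for a single c8 block (illustration only; it ignores the trash-pointer and right-edge handling that the real code performs):

#include <functional>

// Scalar reference: din is one c8 block stored as [h, w, 8]; dout is NCHW.
// cs is the first output channel of the block; act applies the fused
// activation (identity / relu / relu6 / leaky relu) per element.
void write_c8_block_ref(const float* din, float* dout,
                        int cs, int channel, int height, int width,
                        const std::function<float(float)>& act) {
  for (int h = 0; h < height; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < 8; ++c) {
        int oc = cs + c;
        if (oc >= channel) continue;  // padded channels are discarded
        dout[(oc * height + h) * width + w] =
            act(din[(h * width + w) * 8 + c]);
      }
    }
  }
}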
lite/backends/arm/math/conv_depthwise.h  View file @ 3455ab0a
...
...
@@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
                                const float* weights,
                                const float* bias,
                                const operators::ConvParam& param,
+                               const operators::ActivationParam act_param,
                                ARMContext* ctx);
 void conv_3x3s2_depthwise_fp32(const float* i_data,
...
...
@@ -67,6 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din,
                                int pad,
                                bool flag_bias,
                                bool flag_relu,
+                               const operators::ActivationParam act_param,
                                ARMContext* ctx);
 void conv_depthwise_3x3s2_fp32(const float* din,
...
...
lite/backends/arm/math/conv_impl.cc  View file @ 3455ab0a
...
...
@@ -579,6 +579,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                             ARMContext* ctx,
                             const float* scale) {
   auto paddings = *param.paddings;
+  auto act_param = param.activation_param;
   const int pad_h = paddings[0];
   const int pad_w = paddings[2];
   int stride = param.strides[1];
...
...
@@ -603,6 +604,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                               pad,
                               flag_bias,
                               flag_relu,
+                              act_param,
                               ctx);
   } else {
     conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
...
...
@@ -617,6 +619,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                               reinterpret_cast<const float*>(weights),
                               bias,
                               param,
+                              act_param,
                               ctx);
   }
...
...
lite/kernels/arm/conv_compute.cc  View file @ 3455ab0a
...
...
@@ -67,7 +67,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
     VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
-             no_dilation) {
+             no_dilation && pads_all_equal) {
     /// winograd conv impl
     impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
     VLOG(3) << "invoking winograd conv";
...
...
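The new `pads_all_equal` condition keeps the Winograd path limited to symmetric padding. The guard is presumably derived from the four-element {top, bottom, left, right} padding vector along these lines (a hypothetical reconstruction; the actual definition sits earlier in PrepareForRun):

#include <vector>

// Hypothetical sketch of the guard used above: Winograd requires one
// uniform pad value on all four sides.
static bool PadsAllEqual(const std::vector<int>& paddings) {
  return paddings[0] == paddings[1] && paddings[1] == paddings[2] &&
         paddings[2] == paddings[3];
}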
lite/operators/conv_op.cc  View file @ 3455ab0a
...
...
@@ -52,6 +52,34 @@ inline int ConvOutputSize(int input_size,
   return output_size;
 }
+inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
+                                     std::vector<int>* dilations,
+                                     const std::vector<int>& strides,
+                                     const std::string padding_algorithm,
+                                     const lite::DDim data_dims,
+                                     const lite::DDim& ksize) {
+  // when padding_desc is "VALID" or "SAME"
+  if (padding_algorithm == "SAME") {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
+      int pad_sum = std::max(
+          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
+          (int64_t)0);
+      int pad_0 = pad_sum / 2;
+      int pad_1 = pad_sum - pad_0;
+      // pad
+      *(paddings->begin() + i * 2) = pad_0;
+      *(paddings->begin() + i * 2 + 1) = pad_1;
+      // dilation
+      *(dilations->begin() + i) = 1;
+    }
+  } else if (padding_algorithm == "VALID") {
+    for (auto& it : *paddings) {
+      it = 0;
+    }
+  }
+}
 bool ConvOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
...
...
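A quick worked example of the SAME branch of `UpdatePaddingAndDilation`: for an input extent of 7 with stride 2 and kernel 3, `out_size = ceil(7 / 2) = 4` and `pad_sum = max((4 - 1) * 2 + 3 - 7, 0) = 2`, so `pad_0 = pad_1 = 1` and the dilation is forced to 1. As a standalone check (illustrative only):

#include <algorithm>
#include <iostream>

int main() {
  int in = 7, stride = 2, k = 3;
  int out_size = (in + stride - 1) / stride;                    // ceil(7/2) = 4
  int pad_sum = std::max((out_size - 1) * stride + k - in, 0);  // 2
  std::cout << "pad_0 = " << pad_sum / 2
            << ", pad_1 = " << pad_sum - pad_sum / 2 << "\n";   // 1, 1
  return 0;
}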
lite/operators/conv_op.h  View file @ 3455ab0a
...
...
@@ -137,34 +137,6 @@ class ConvOpLite : public OpLite {
   std::string padding_algorithm_{""};
 };
-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
-                                     std::vector<int>* dilations,
-                                     const std::vector<int>& strides,
-                                     const std::string padding_algorithm,
-                                     const lite::DDim data_dims,
-                                     const lite::DDim& ksize) {
-  // when padding_desc is "VALID" or "SAME"
-  if (padding_algorithm == "SAME") {
-    for (size_t i = 0; i < strides.size(); ++i) {
-      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
-      int pad_sum = std::max(
-          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
-          (int64_t)0);
-      int pad_0 = pad_sum / 2;
-      int pad_1 = pad_sum - pad_0;
-      // pad
-      *(paddings->begin() + i * 2) = pad_0;
-      *(paddings->begin() + i * 2 + 1) = pad_1;
-      // dilation
-      *(dilations->begin() + i) = 1;
-    }
-  } else if (padding_algorithm == "VALID") {
-    for (auto& it : *paddings) {
-      it = 0;
-    }
-  }
-}
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
lite/tests/math/conv_compute_test.cc  View file @ 3455ab0a
...
...
@@ -59,6 +59,8 @@ DEFINE_bool(flag_bias, true, "with bias");
 typedef paddle::lite::DDim DDim;
 typedef paddle::lite::Tensor Tensor;
 typedef paddle::lite::operators::ConvParam ConvParam;
+typedef paddle::lite::operators::ActivationParam ActivationParam;
+using paddle::lite::profile::Timer;
 DDim compute_out_dim(const DDim& dim_in,
...
...
@@ -118,6 +120,13 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
   param.dilations = std::make_shared<std::vector<int>>(dilas);
   param.fuse_relu = flag_relu;
   param.groups = group;
+  if (flag_relu) {
+    ActivationParam act_param;
+    act_param.has_active = true;
+    act_param.active_type = (paddle::lite_api::ActivationType)1;
+    // 2-relu6 4-leakyrelu
+    param.activation_param = act_param;
+  }
   param.output = new Tensor;
   param.output->set_precision(PRECISION(kFloat));
...
...
@@ -243,6 +252,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
           << pads[2] << ", " << pads[3]
           << ", stride: " << strides[0] << ", " << strides[1]
           << ", dila_: " << dilas[0] << ", " << dilas[1]
+          << ", group: " << group
           << ", bias: " << (flag_bias ? "true" : "false")
           << ", relu: " << (flag_relu ? "true" : "false")
           << ", threads: " << th << ", power_mode: " << cls
...
...
@@ -255,6 +265,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
           << ", pad: " << pads[0] << ", " << pads[1]
           << ", stride: " << strides[0] << ", " << strides[1]
           << ", dila_: " << dilas[0] << ", " << dilas[1]
+          << ", group: " << group
           << ", bias: " << (flag_bias ? "true" : "false")
           << ", relu: " << (flag_relu ? "true" : "false")
           << ", threads: " << th << ", power_mode: " << cls
...
...
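Per the inline comment in the test hunk above, `active_type` 1 selects relu, 2 relu6, and 4 leaky relu. To exercise the two new fused paths, the same block would be extended roughly as follows (a sketch continuing the test code, not part of the patch; the coefficient values are arbitrary):

// Sketch: drive the new fusions from the test, reusing the fields the
// kernels read (Relu_clipped_coef for relu6, Leaky_relu_alpha for leaky).
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)2;  // relu6
act_param.Relu_clipped_coef = 6.f;
// or:
// act_param.active_type = (paddle::lite_api::ActivationType)4;  // leaky relu
// act_param.Leaky_relu_alpha = 0.1f;  // arbitrary slope for testing
param.activation_param = act_param;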