PaddlePaddle / Paddle-Lite
Commit 74a309cb
Authored Jul 17, 2019 by StarryRain; committed by Yanzhan Yang on Jul 17, 2019

add CPU_ARCH info, improve the performance of GEMM1*1s1 (#1751)

Parent: 497bf326
Showing 17 changed files with 2,420 additions and 4 deletions (+2420 −4)
src/common/types.h                                                    +12    -0
src/framework/context.cpp                                             +45    -4
src/framework/context.h                                                +2    -0
src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp       +3    -0
src/operators/kernel/arm/convolution/conv_add_kernel.cpp               +3    -0
src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp          +3    -0
src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp       +3    -0
src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp           +3    -0
src/operators/kernel/arm/convolution/conv_common.cpp                  +20    -0
src/operators/kernel/arm/convolution/conv_kernel.cpp                   +3    -0
src/operators/kernel/arm/convolution/conv_relu_kernel.cpp              +3    -0
src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp         +3    -0
src/operators/kernel/central-arm-func/conv_arm_func.cpp               +58    -0
src/operators/kernel/central-arm-func/conv_arm_func.h                  +3    -0
src/operators/math/gemm/gemm1x1s1.cpp                               +2198    -0
src/operators/math/gemm/gemm1x1s1.h                                   +57    -0
src/operators/op_param.h                                               +1    -0
src/common/types.h (view file @ 74a309cb)

@@ -145,6 +145,18 @@ struct PaddleMobileConfigInternal {
   std::string model_obfuscate_key = "";
 };
 
+enum ARMArch {
+  APPLE = 0,
+  A53 = 53,
+  A55 = 55,
+  A57 = 57,
+  A72 = 72,
+  A73 = 73,
+  A75 = 75,
+  A76 = 76,
+  ARM_UNKOWN = -1
+};
+
 extern const char *G_OP_TYPE_CONV;
 extern const char *G_OP_TYPE_BATCHNORM;
 extern const char *G_OP_TYPE_BOX_CODER;
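The enum values are read through the new CPUContext::get_arch() accessor added below and drive micro-kernel selection. A minimal sketch of the intended use (the A73 special case mirrors the prepackA dispatch added in gemm1x1s1.cpp later in this commit):

  #include "common/types.h"
  #include "framework/context.h"

  // Pick a packing micro-kernel based on the probed core type. On armv7,
  // A73 cores get the narrower 4x8 packing; everything else falls back
  // to 6x8 (see prepackA in gemm1x1s1.cpp below).
  ARMArch arch = paddle_mobile::framework::CPUContext::Context()->get_arch();
  bool use_4x8_packing = (arch == A73);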
src/framework/context.cpp (view file @ 74a309cb)

@@ -261,7 +261,8 @@ int set_sched_affinity(const std::vector<int> &cpu_ids) {
   return 0;
 }
 
-int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
+int get_cpu_info_by_name(int *cpu_num, ARMArch *arch,
+                         std::vector<int> *big_core_ids,
                          std::vector<int> *little_core_ids,
                          std::vector<int> *l1_cache_sizes,
                          std::vector<int> *l2_cache_sizes,
@@ -270,6 +271,7 @@ int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
   /* Snapdragon */
   if (hardware_name.find("SDM845") != std::string::npos) {  // 845
     *cpu_num = 8;
+    *arch = A75;
     *big_core_ids = {4, 5, 6, 7};
     *little_core_ids = {0, 1, 2, 3};
     l1_cache_sizes->resize(*cpu_num);
@@ -282,6 +284,7 @@ int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
     return 0;
   } else if (hardware_name.find("SDM710") != std::string::npos) {  // 710
     *cpu_num = 8;
+    *arch = A75;
     *big_core_ids = {6, 7};
     *little_core_ids = {0, 1, 2, 3, 4, 5};
     l1_cache_sizes->resize(*cpu_num);
@@ -295,6 +298,7 @@ int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
     return 0;
   } else if (hardware_name.find("MSM8998") != std::string::npos) {  // 835
     *cpu_num = 8;
+    *arch = A73;
     *big_core_ids = {4, 5, 6, 7};
     *little_core_ids = {0, 1, 2, 3};
     l1_cache_sizes->resize(*cpu_num);
@@ -313,8 +317,9 @@ int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
     return 0;
   } else if (hardware_name.find("MSM8976") != std::string::npos) {  // 652,653
     *cpu_num = 8;
-    *big_core_ids = {0, 1, 2, 3, 4, 5, 6, 7};
-    *little_core_ids = {};
+    *arch = A72;
+    *big_core_ids = {4, 5, 6, 7};
+    *little_core_ids = {0, 1, 2, 3};
     l1_cache_sizes->resize(*cpu_num);
     l2_cache_sizes->resize(*cpu_num);
     l3_cache_sizes->resize(*cpu_num);
@@ -322,6 +327,42 @@ int get_cpu_info_by_name(int *cpu_num, std::vector<int> *big_core_ids,
     fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024);
     fill_cpu_cache_size(l3_cache_sizes, 0);
     return 0;
+  } else if (hardware_name.find("SDM660") != std::string::npos ||
+             hardware_name.find("SDM636") != std::string::npos) {  // 660, 636
+    *cpu_num = 8;
+    *arch = A73;
+    *big_core_ids = {4, 5, 6, 7};
+    *little_core_ids = {0, 1, 2, 3};
+    l1_cache_sizes->resize(*cpu_num);
+    l2_cache_sizes->resize(*cpu_num);
+    l3_cache_sizes->resize(*cpu_num);
+    fill_cpu_cache_size(l1_cache_sizes, 64 * 1024);
+    fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024);
+    fill_cpu_cache_size(l3_cache_sizes, 0);
+    return 0;
+    /* MediaTek */
+  } else if (hardware_name.find("MT6799") != std::string::npos) {  // X30
+    *cpu_num = 10;
+    *arch = A73;
+    *big_core_ids = {8, 9};
+    *little_core_ids = {0, 1, 2, 3, 4, 5, 6, 7};
+    return 0;
+  } else if (hardware_name.find("MT6771") != std::string::npos) {  // P60
+    *cpu_num = 8;
+    *arch = A73;
+    *big_core_ids = {4, 5, 6, 7};
+    *little_core_ids = {0, 1, 2, 3};
+    return 0;
+    /* Kirin */
+  } else if (hardware_name.find("KIRIN970") != std::string::npos) {  // Kirin 970
+    *cpu_num = 8;
+    *arch = A73;
+    *big_core_ids = {4, 5, 6, 7};
+    *little_core_ids = {0, 1, 2, 3};
+    return 0;
+  }
   }
   return -1;
 }
@@ -410,7 +451,7 @@ CPUContext::CPUContext() {
   // probe cpu info, and set big&litte clusters, L1, L2 and L3 cache sizes
   std::string cpu_name = get_cpu_name();
   bool failed =
-      get_cpu_info_by_name(&_cpu_num, &_big_core_ids, &_little_core_ids,
+      get_cpu_info_by_name(&_cpu_num, &_arch, &_big_core_ids, &_little_core_ids,
                            &_l1_cache_sizes, &_l2_cache_sizes, &_l3_cache_sizes,
                            cpu_name) != 0;
   if (failed) {
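For reference, a minimal (hypothetical) probe of the new field. CPUContext::Context() is the singleton accessor this commit uses elsewhere, and the cache getters already existed; the print helper here is illustrative only:

  #include <cstdio>
  #include "framework/context.h"

  using paddle_mobile::framework::CPUContext;

  void dump_cpu_info() {
    CPUContext *ctx = CPUContext::Context();
    ARMArch arch = ctx->get_arch();     // new in this commit
    int l2 = ctx->get_l2_cache_size();  // pre-existing accessor
    // On an SoC not in the name table, get_cpu_info_by_name() returns -1
    // and the constructor falls back to generic probing, so arch may
    // remain ARM_UNKOWN here.
    printf("arch enum value: %d, L2: %d bytes\n", static_cast<int>(arch), l2);
  }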
src/framework/context.h (view file @ 74a309cb)

@@ -43,12 +43,14 @@ struct CPUContext {
   int get_thread_num();
   PowerMode get_power_mode() const { return _power_mode; }
   int get_cache_size(int level);
+  ARMArch get_arch() const { return _arch; }
   int get_l1_cache_size() { return get_cache_size(1); }
   int get_l2_cache_size() { return get_cache_size(2); }
   int get_l3_cache_size() { return get_cache_size(3); }
   void *get_work_space(int size_in_byte);
 
   int _cpu_num;
+  ARMArch _arch;
   PowerMode _power_mode;
   std::vector<int> _big_core_ids;
   std::vector<int> _little_core_ids;
src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp (view file @ 74a309cb)

@@ -126,6 +126,9 @@ void ConvAddBNReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
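The same three-line case is added to each of the ARM convolution kernels below. The common dispatch shape they share (a sketch with the surrounding code elided): ExecMode() is fixed once in InitBaseConvKernel (see conv_common.cpp below), then each Compute() routes to the matching implementation.

  switch (param.ExecMode()) {
    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
      GemmConv<float, float>(param);       // im2col + generic GEMM
      break;
    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
      GemmConv1x1s1<float, float>(param);  // new: direct pre-packed GEMM
      break;
    // ... sliding-window, winograd and depthwise modes elided ...
    default:
      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                    param.ExecMode());
  }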
src/operators/kernel/arm/convolution/conv_add_kernel.cpp (view file @ 74a309cb)

@@ -44,6 +44,9 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp (view file @ 74a309cb)

@@ -45,6 +45,9 @@ void ConvAddReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp (view file @ 74a309cb)

@@ -64,6 +64,9 @@ void ConvBNAddReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp (view file @ 74a309cb)

@@ -77,6 +77,9 @@ void ConvBNReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
src/operators/kernel/arm/convolution/conv_common.cpp (view file @ 74a309cb)

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "operators/kernel/arm/convolution/conv_common.h"
+#include "framework/context.h"
+#include "operators/math/gemm/gemm1x1s1.h"
 #include "operators/math/slidingwindow_utils.h"
 #include "operators/math/winograd/winograd_transform.h"
@@ -20,6 +22,8 @@ namespace paddle_mobile {
 namespace operators {
 
 void InitBaseConvKernel(ConvParam<CPU> *param) {
+  bool conv1x1 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+                 param->Filter()->dims()[2] == 1;
   bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                  param->Filter()->dims()[2] == 3;
   bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
@@ -83,6 +87,22 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
       math::slidingwindow_transform_weight<float>(*param->Filter(),
                                                   param->transformed_filter_);
       param->ExecMode() = ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT;
+    } else if (conv1x1 && param->Groups() == 1 &&
+               param->Paddings()[0] == param->Paddings()[1] &&
+               param->Paddings()[0] == 0 && param->Input()->dims()[1] > 1 &&
+               param->Strides()[0] == param->Strides()[1] &&
+               param->Dilations()[0] == param->Dilations()[1] &&
+               param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
+               param->Output()->dims()[2] * param->Output()->dims()[3] > 1) {
+      // transform weight
+      Variable *transformed_var = param->GetScope()->Var();
+      ARMArch arch = framework::CPUContext::Context()->get_arch();
+      param->transformed_filter_ =
+          transformed_var->GetMutable<framework::LoDTensor>();
+      math::gemm1x1s1_transform_weight(*param->Filter(), *param->Output(),
+                                       param->transformed_filter_,
+                                       param->groups, arch);
+      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT;
     } else {
       param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
     }
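Restated as a standalone predicate (a sketch for exposition; names mirror the ConvParam accessors above), the new fast path is taken only when the convolution degenerates to a plain matrix multiply:

  // A 1x1, stride-1, pad-0, dilation-1, group-1 convolution over more than
  // one input channel and more than one output pixel is exactly
  // C[m][n] = sum_k W[m][k] * X[k][n], so it can skip im2col entirely.
  bool qualifies_for_gemm1x1s1(const ConvParam<CPU> *param) {
    auto &f = param->Filter()->dims();
    bool conv1x1 = f[2] == f[3] && f[2] == 1;
    return conv1x1 && param->Groups() == 1 &&
           param->Paddings()[0] == param->Paddings()[1] &&
           param->Paddings()[0] == 0 &&
           param->Input()->dims()[1] > 1 &&
           param->Strides()[0] == param->Strides()[1] &&
           param->Dilations()[0] == param->Dilations()[1] &&
           param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
           param->Output()->dims()[2] * param->Output()->dims()[3] > 1;
  }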
src/operators/kernel/arm/convolution/conv_kernel.cpp (view file @ 74a309cb)

@@ -54,6 +54,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
src/operators/kernel/arm/convolution/conv_relu_kernel.cpp (view file @ 74a309cb)

@@ -45,6 +45,9 @@ void ConvReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp (view file @ 74a309cb)

@@ -76,6 +76,9 @@ void DWConvBNReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                     param.ExecMode());
src/operators/kernel/central-arm-func/conv_arm_func.cpp (view file @ 74a309cb)

@@ -14,9 +14,11 @@ limitations under the License. */
 
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
 #include <vector>
+#include "framework/context.h"
 #include "operators/math/depthwise/faster_depthwise_conv3x3.h"
 #include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/depthwise_conv5x5.h"
+#include "operators/math/gemm/gemm1x1s1.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/pad.h"
@@ -137,6 +139,61 @@ void GemmConv(const ConvParam<CPU> &param) {
   }
 }
 
+template <typename Itype, typename Otype>
+void GemmConv1x1s1(const ConvParam<CPU> &param) {
+  const Tensor *input = param.Input();
+  Tensor filter = *param.transformed_filter_;
+  Tensor *output = param.Output();
+  output->mutable_data<Otype>();
+
+  const float *din = input->data<Itype>();
+  float *dout = output->mutable_data<Otype>();
+  const int num = input->dims()[0];
+  const int chin = input->dims()[1];
+  const int hin = input->dims()[2];
+  const int win = input->dims()[3];
+  const int chout = output->dims()[1];
+  const int hout = output->dims()[2];
+  const int wout = output->dims()[3];
+  const float *weights = filter.mutable_data<float>();
+  const float *bias = nullptr;
+
+  int channel_size_out = wout * hout;
+  int channel_size_in = win * hin;
+  const int group = param.Groups();
+  const int m = chout / group;
+  const int n = hout * wout;
+  const int k = chin / group;
+  bool flag_relu = false;
+  bool flag_bias = false;
+  ARMArch arch = framework::CPUContext::Context()->get_arch();
+  int hblock = math::get_hblock(arch);
+  int m_roundup = hblock * ((m + hblock - 1) / hblock);
+  int weights_size_per_group = m * k;
+  if (n > 1) {
+    weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
+  }
+  for (int b = 0; b < num; ++b) {
+    // dC
+    for (int g = 0; g < group; ++g) {
+      float *dout_group =
+          static_cast<float *>(dout) + (b * chout + g * m) * channel_size_out;
+      const float *din_group = static_cast<const float *>(din) +
+                               (b * chin + g * k) * channel_size_in;
+      const float *weights_group =
+          static_cast<const float *>(weights) + g * weights_size_per_group;
+      const float *bias_group = static_cast<const float *>(bias) + g * m;
+
+      if (n > 1) {
+        math::sgemm_prepack(weights_group, din_group, bias_group, dout_group,
+                            m, n, k, flag_bias, flag_relu, false, arch);
+      }
+    }
+  }
+}
+
 template <int tile, int kernel>
 void WinogradConv3x3(const ConvParam<CPU> &param) {
   const Tensor *input = param.Input();
@@ -293,6 +350,7 @@ void SlidingwindowConv3x3(const ConvParam<CPU> &param) {
 }
 
 template void GemmConv<float, float>(const ConvParam<CPU> &param);
+template void GemmConv1x1s1<float, float>(const ConvParam<CPU> &param);
 template void WinogradConv3x3<8, 3>(const ConvParam<CPU> &param);
 template void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param);
 template void DepthwiseConv5x5<float, float>(const ConvParam<CPU> &param);
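Why this can be handed straight to a GEMM: with kernel size 1 and stride 1, each output pixel depends on exactly one input pixel, so per batch b and group g the computation collapses to C = W·X with m = chout/group rows, n = hout·wout columns and k = chin/group inner products. A naive reference of the semantics (a sketch for exposition, not part of the commit; the real path uses pre-packed weights and the NEON kernel below):

  // Reference semantics of GemmConv1x1s1 for one (batch, group) pair:
  // dout_group[i][j] = sum over p of weights_group[i][p] * din_group[p][j],
  // where j indexes the hout*wout output pixels laid out contiguously.
  void gemm1x1s1_reference(const float *w, const float *x, float *y,
                           int m, int n, int k) {
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int p = 0; p < k; ++p) {
          acc += w[i * k + p] * x[p * n + j];
        }
        y[i * n + j] = acc;
      }
    }
  }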
src/operators/kernel/central-arm-func/conv_arm_func.h (view file @ 74a309cb)

@@ -32,6 +32,9 @@ bool IsExpand(const std::vector<int64_t> &filter_dim,
 template <typename Itype, typename Otype>
 void GemmConv(const ConvParam<CPU> &param);
 
+template <typename Itype, typename Otype>
+void GemmConv1x1s1(const ConvParam<CPU> &param);
+
 template <int tile, int kernel>
 void WinogradConv3x3(const ConvParam<CPU> &param);
src/operators/math/gemm/gemm1x1s1.cpp (new file, 0 → 100644, view file @ 74a309cb)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#ifdef CONV_OP

#include "operators/math/gemm/gemm1x1s1.h"
#include <arm_neon.h>
#include "framework/context.h"
#include "iostream"

namespace paddle_mobile {
namespace operators {
namespace math {

#ifdef __aarch64__
void prepackA_8x12(float *out, const float *in, const int ldin, const int m0,
                   const int mmax, const int k0, const int kmax) {
  int x_len = kmax - k0;
  uint32_t zerobuff[x_len];
  memset(zerobuff, 0, sizeof(uint32_t) * x_len);

  uint32_t *dout = reinterpret_cast<uint32_t *>(out);
  const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);

  int stride = x_len * 8;
#pragma omp parallel for
  for (int y = m0; y < mmax; y += 8) {
    uint32_t *outptr = dout + stride * (y - m0) / 8;
    const uint32_t *inptr0 = inptr + y * ldin + k0;
    const uint32_t *inptr1 = inptr0 + ldin;
    const uint32_t *inptr2 = inptr1 + ldin;
    const uint32_t *inptr3 = inptr2 + ldin;
    const uint32_t *inptr4 = inptr3 + ldin;
    const uint32_t *inptr5 = inptr4 + ldin;
    const uint32_t *inptr6 = inptr5 + ldin;
    const uint32_t *inptr7 = inptr6 + ldin;
    asm volatile(
        "prfm pldl1keep, [%[ptr0]]        \n"
        "prfm pldl1keep, [%[ptr0], #64]   \n"
        "prfm pldl1keep, [%[ptr1]]        \n"
        "prfm pldl1keep, [%[ptr1], #64]   \n"
        "prfm pldl1keep, [%[ptr2]]        \n"
        "prfm pldl1keep, [%[ptr2], #64]   \n"
        "prfm pldl1keep, [%[ptr3]]        \n"
        "prfm pldl1keep, [%[ptr3], #64]   \n"
        "prfm pldl1keep, [%[ptr4]]        \n"
        "prfm pldl1keep, [%[ptr4], #64]   \n"
        "prfm pldl1keep, [%[ptr5]]        \n"
        "prfm pldl1keep, [%[ptr5], #64]   \n"
        "prfm pldl1keep, [%[ptr6]]        \n"
        "prfm pldl1keep, [%[ptr6], #64]   \n"
        "prfm pldl1keep, [%[ptr7]]        \n"
        "prfm pldl1keep, [%[ptr7], #64]   \n"
        :
        : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2),
          [ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5),
          [ptr6] "r"(inptr6), [ptr7] "r"(inptr7)
        : "memory");

    int x = x_len;
    //! cope with row index exceed real size, set to zero buffer
    if ((y + 7) >= mmax) {
      switch ((y + 7) - mmax) {
        case 6:
          inptr1 = zerobuff;
        case 5:
          inptr2 = zerobuff;
        case 4:
          inptr3 = zerobuff;
        case 3:
          inptr4 = zerobuff;
        case 2:
          inptr5 = zerobuff;
        case 1:
          inptr6 = zerobuff;
        case 0:
          inptr7 = zerobuff;
        default:
          break;
      }
    }
    for (; x > 7; x -= 8) {
      asm volatile(
          // Load up 8 elements (2 vectors) from each of 8 sources.
          "LDP q0, q1, [%[inptr0]], #32\n"  // q0=A0A1A2A3
          "LDP q2, q3, [%[inptr1]], #32\n"  // q2=B0B1B2B3
          "LDP q4, q5, [%[inptr2]], #32\n"  // q4=C0C1C2C3
          "ZIP1 v16.4s, v0.4s, v4.4s\n"     // q16=A0C0A1C1
          "prfm pldl1keep, [%[inptr0], #128]\n"
          "LDP q6, q7, [%[inptr3]], #32\n"  // q6=D0D1D2D3
          "ZIP1 v17.4s, v2.4s, v6.4s\n"     // q17=B0D0B1D1
          "LDP q8, q9, [%[inptr4]], #32\n"
          "LDP q10, q11, [%[inptr5]], #32\n"
          "LDP q12, q13, [%[inptr6]], #32\n"
          "ZIP1 v18.4s, v8.4s, v12.4s\n"
          "prfm pldl1keep, [%[inptr1], #128]\n"
          "LDP q14, q15, [%[inptr7]], #32\n"
          "ZIP1 v19.4s, v10.4s, v14.4s\n"
          "ZIP1 v20.4s, v16.4s, v17.4s\n"  // q20=A0B0C0D0
          "prfm pldl1keep, [%[inptr2], #128]\n"
          "ZIP1 v21.4s, v18.4s, v19.4s\n"
          "ZIP2 v22.4s, v16.4s, v17.4s\n"
          "ZIP2 v23.4s, v18.4s, v19.4s\n"
          "ZIP2 v16.4s, v0.4s, v4.4s\n"
          "prfm pldl1keep, [%[inptr3], #128]\n"
          "ZIP2 v17.4s, v2.4s, v6.4s\n"
          "STP q20, q21, [%[outptr]], #32\n"  // Write back the first
                                              // element of each source
          "ZIP2 v18.4s, v8.4s, v12.4s\n"
          "ZIP2 v19.4s, v10.4s, v14.4s\n"
          "STP q22, q23, [%[outptr]], #32\n"  // Write back the second
                                              // element of each source
          "ZIP1 v20.4s, v16.4s, v17.4s\n"
          "prfm pldl1keep, [%[inptr4], #128]\n"
          "ZIP1 v21.4s, v18.4s, v19.4s\n"
          "ZIP2 v22.4s, v16.4s, v17.4s\n"
          "ZIP2 v23.4s, v18.4s, v19.4s\n"
          "ZIP1 v16.4s, v1.4s, v5.4s\n"
          "prfm pldl1keep, [%[inptr5], #128]\n"
          "ZIP1 v17.4s, v3.4s, v7.4s\n"
          "STP q20, q21, [%[outptr]], #32\n"  // Third element
          "ZIP1 v18.4s, v9.4s, v13.4s\n"
          "ZIP1 v19.4s, v11.4s, v15.4s\n"
          "STP q22, q23, [%[outptr]], #32\n"  // Fourth element
          "ZIP1 v20.4s, v16.4s, v17.4s\n"
          "ZIP1 v21.4s, v18.4s, v19.4s\n"
          "ZIP2 v22.4s, v16.4s, v17.4s\n"
          "prfm pldl1keep, [%[inptr6], #128]\n"
          "ZIP2 v23.4s, v18.4s, v19.4s\n"
          "ZIP2 v16.4s, v1.4s, v5.4s\n"
          "ZIP2 v17.4s, v3.4s, v7.4s\n"
          "STP q20, q21, [%[outptr]], #32\n"  // Fifth element
          "ZIP2 v18.4s, v9.4s, v13.4s\n"
          "prfm pldl1keep, [%[inptr7], #128]\n"
          "ZIP2 v19.4s, v11.4s, v15.4s\n"
          "STP q22, q23, [%[outptr]], #32\n"  // Sixth element
          "ZIP1 v20.4s, v16.4s, v17.4s\n"
          "ZIP1 v21.4s, v18.4s, v19.4s\n"
          "STP q20, q21, [%[outptr]], #32\n"  // Seventh element
          "ZIP2 v22.4s, v16.4s, v17.4s\n"
          "ZIP2 v23.4s, v18.4s, v19.4s\n"
          "STP q22, q23, [%[outptr]], #32\n"  // Eighth element
          : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
            [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
            [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
            [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7),
            [outptr] "+r"(outptr)
          :
          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
            "v19", "v20", "v21", "v22", "v23", "cc", "memory");
    }
    for (; x > 0; x--) {
      *outptr++ = *inptr0++;
      *outptr++ = *inptr1++;
      *outptr++ = *inptr2++;
      *outptr++ = *inptr3++;
      *outptr++ = *inptr4++;
      *outptr++ = *inptr5++;
      *outptr++ = *inptr6++;
      *outptr++ = *inptr7++;
    }
  }
}
#else  //__aarch64__
void prepackA_6x8(float *out, const float *in, const int ldin, const int m0,
                  const int mmax, const int k0, const int kmax) {
  int x_len = kmax - k0;
  uint32_t zerobuff[x_len];
  memset(zerobuff, 0, sizeof(uint32_t) * x_len);

  uint32_t *dout = reinterpret_cast<uint32_t *>(out);
  const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);

  uint32_t *outptr = dout;
  //! data A is not transposed, transpose A to k * 6
  for (int y = m0; y < mmax; y += 6) {
    const uint32_t *inptr0 = inptr + y * ldin + k0;
    const uint32_t *inptr1 = inptr0 + ldin;
    const uint32_t *inptr2 = inptr1 + ldin;
    const uint32_t *inptr3 = inptr2 + ldin;
    const uint32_t *inptr4 = inptr3 + ldin;
    const uint32_t *inptr5 = inptr4 + ldin;

    int x = x_len;
    //! cope with row index exceed real size, set to zero buffer
    if ((y + 5) >= mmax) {
      switch ((y + 5) - mmax) {
        case 4:
          inptr1 = zerobuff;
        case 3:
          inptr2 = zerobuff;
        case 2:
          inptr3 = zerobuff;
        case 1:
          inptr4 = zerobuff;
        case 0:
          inptr5 = zerobuff;
        default:
          break;
      }
    }
    for (; x > 7; x -= 8) {
      //! zip load 8 elements (2 neon Q registers) from each of 6 rows
      asm volatile(
          "vld4.32 {d0-d3}, [%[inptr0]]!   @ zip load r0, "
          "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n"
          "vld4.32 {d4-d7}, [%[inptr1]]!   @ zip load r1, "
          "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n"
          "vld4.32 {d8-d11}, [%[inptr2]]!  @ zip load r2, "
          "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n"
          "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, "
          "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n"
          "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, "
          "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n"
          "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, "
          "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n"
          "vtrn.32 q0, q2   @ trans data: q0=r00,r10,r01,r11; "
          "q2=r04,r14,r05,r15\n"
          "vtrn.32 q4, q6   @ trans data: q4=r20,r30,r21,r31; "
          "q6=r24,r34,r25,r35\n"
          "vtrn.32 q8, q10  @ trans data: q8=r40,r50,r41,r51; "
          "q10=r44,r54,r45,r55\n"
          "vswp d1, d8      @ swap d1, d8, q0=r00,r10,r20,r30; "
          "q4=r01,r11,r21,r31\n"
          "vst1.32 {d0-d1}, [%[outptr]]!   @ write q0:r00,r10,r20,r30\n"
          "vst1.32 {d16}, [%[outptr]]!     @ write d16(q8,low),r40,r50\n"
          "vst1.32 {d8-d9}, [%[outptr]]!   @ write q4:r01,r11,r21,r31\n"
          "vst1.32 {d17}, [%[outptr]]!     @ write d16(q8,high),r41,r51\n"
          "vtrn.32 q1, q3   @ trans data: q1=r02,r12,r03,r13; "
          "q3=r06,r16,r07,r17\n"
          "vtrn.32 q5, q7   @ trans data: q5=r22,r32,r23,r33; "
          "q7=r26,r36,r27,r37\n"
          "vtrn.32 q9, q11  @ trans data: q9=r42,r52,r43,r53; "
          "q11=r46,r56,r47,r57\n"
          "vswp d3, d10     @ swap d3, d10, "
          "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n"
          "vst1.32 {d2-d3}, [%[outptr]]!   @ write q1:r02,r12,r22,r32\n"
          "vst1.32 {d18}, [%[outptr]]!     @ write d18(q9,low),r42,r52\n"
          "vst1.32 {d10-d11},[%[outptr]]!  @ write q5:r03,r13,r23,r33\n"
          "vst1.32 {d19}, [%[outptr]]!     @ write d19(q9,high),r43,r53\n"
          "vswp d5, d12     @ swap d5, d12,q2=r04,r14,r24,r34; "
          "q6=r05,r15,r25,r35\n"
          "vst1.32 {d4-d5}, [%[outptr]]!   @ write q2:r04,r14,r24,r34\n"
          "vst1.32 {d20}, [%[outptr]]!     @ write d20(q10,low),r44,r54\n"
          "vst1.32 {d12-d13},[%[outptr]]!  @ write q6:r05,r15,r25,r35\n"
          "vst1.32 {d21}, [%[outptr]]!     @ write d21(q10,high),r45,r55\n"
          "vswp d7, d14     @ swap d7, d14, "
          "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n"
          "vst1.32 {d6-d7}, [%[outptr]]!   @ write q3:r06,r16,r26,r36\n"
          "vst1.32 {d22}, [%[outptr]]!     @ write d22(q11,low),r46,r56\n"
          "vst1.32 {d14-d15},[%[outptr]]!  @ write q7:r07,r17,r27,r37\n"
          "vst1.32 {d23}, [%[outptr]]!     @ write d23(q11,high),r47,r57\n"
          : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
            [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
            [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
            [outptr] "+r"(outptr)
          :
          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
            "q10", "q11", "cc", "memory");
    }
    for (; x > 0; x--) {
      *outptr++ = *inptr0++;
      *outptr++ = *inptr1++;
      *outptr++ = *inptr2++;
      *outptr++ = *inptr3++;
      *outptr++ = *inptr4++;
      *outptr++ = *inptr5++;
    }
  }
}
void prepackA_4x8(float *out, const float *in, const int ldin, const int m0,
                  const int mmax, const int k0, const int kmax) {
  int x_len = kmax - k0;
  uint32_t zerobuff[x_len];
  memset(zerobuff, 0, sizeof(uint32_t) * x_len);

  uint32_t *dout = reinterpret_cast<uint32_t *>(out);
  const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);

  uint32_t *outptr = dout;
  //! data A is not transposed, transpose A to k * 4
  for (int y = m0; y < mmax; y += 4) {
    const uint32_t *inptr0 = inptr + y * ldin + k0;
    const uint32_t *inptr1 = inptr0 + ldin;
    const uint32_t *inptr2 = inptr1 + ldin;
    const uint32_t *inptr3 = inptr2 + ldin;

    int x = x_len;
    //! cope with row index exceed real size, set to zero buffer
    if ((y + 3) >= mmax) {
      switch ((y + 3) - mmax) {
        case 2:
          inptr1 = zerobuff;
        case 1:
          inptr2 = zerobuff;
        case 0:
          inptr3 = zerobuff;
        default:
          break;
      }
    }
    for (; x > 7; x -= 8) {
      //! zip load 8 elements (2 neon Q registers) from each of 4 rows
      asm volatile(
          "vld4.32 {d0-d3}, [%[inptr0]]!   @ zip load r0, "
          "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n"
          "vld4.32 {d4-d7}, [%[inptr1]]!   @ zip load r1, "
          "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n"
          "vld4.32 {d8-d11}, [%[inptr2]]!  @ zip load r2, "
          "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n"
          "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, "
          "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n"
          "vtrn.32 q0, q2  @ trans data: q0=r00,r10,r01,r11; "
          "q2=r04,r14,r05,r15\n"
          "vtrn.32 q4, q6  @ trans data: q4=r20,r30,r21,r31; "
          "q6=r24,r34,r25,r35\n"
          "vswp d1, d8     @ swap d1, d8, q0=r00,r10,r20,r30; "
          "q4=r01,r11,r21,r31\n"
          "vst1.32 {d0-d1}, [%[outptr]]!   @ write q0:r00,r10,r20,r30\n"
          "vst1.32 {d8-d9}, [%[outptr]]!   @ write q4:r01,r11,r21,r31\n"
          "vtrn.32 q1, q3  @ trans data: q1=r02,r12,r03,r13; "
          "q3=r06,r16,r07,r17\n"
          "vtrn.32 q5, q7  @ trans data: q5=r22,r32,r23,r33; "
          "q7=r26,r36,r27,r37\n"
          "vswp d3, d10    @ swap d3, d10, "
          "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n"
          "vst1.32 {d2-d3}, [%[outptr]]!   @ write q1:r02,r12,r22,r32\n"
          "vst1.32 {d10-d11},[%[outptr]]!  @ write q5:r03,r13,r23,r33\n"
          "vswp d5, d12    @ swap d5, d12,q2=r04,r14,r24,r34; "
          "q6=r05,r15,r25,r35\n"
          "vst1.32 {d4-d5}, [%[outptr]]!   @ write q2:r04,r14,r24,r34\n"
          "vst1.32 {d12-d13},[%[outptr]]!  @ write q6:r05,r15,r25,r35\n"
          "vswp d7, d14    @ swap d7, d14, "
          "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n"
          "vst1.32 {d6-d7}, [%[outptr]]!   @ write q3:r06,r16,r26,r36\n"
          "vst1.32 {d14-d15},[%[outptr]]!  @ write q7:r07,r17,r27,r37\n"
          : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
            [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
            [outptr] "+r"(outptr)
          :
          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
            "q10", "q11", "cc", "memory");
    }
    for (; x > 0; x--) {
      *outptr++ = *inptr0++;
      *outptr++ = *inptr1++;
      *outptr++ = *inptr2++;
      *outptr++ = *inptr3++;
    }
  }
}
#endif  //__aarch64__
void prepackA(float *out, const float *in, const int ldin, const int m0,
              const int mmax, const int k0, const int kmax, bool is_trans,
              ARMArch arch) {
#ifdef __aarch64__
  if (!is_trans) {
    prepackA_8x12(out, in, ldin, m0, mmax, k0, kmax);
  }
#else
  if (arch == A73) {
    if (!is_trans) {
      prepackA_4x8(out, in, ldin, m0, mmax, k0, kmax);
    }
  } else {
    if (!is_trans) {
      prepackA_6x8(out, in, ldin, m0, mmax, k0, kmax);
    }
  }
#endif
}
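Note the per-arch packing widths: 8 rows of A at a time on aarch64 (matching the 8x12 kernel), 6 on generic armv7, and 4 on A73. get_hblock is declared in gemm1x1s1.h, whose diff is truncated in this page; for the packed layout to line up with sgemm_prepack it must return the same row-block width, presumably along these lines (a sketch inferred from the dispatch above, not the committed code):

  // Hypothetical sketch of get_hblock, inferred from the prepackA dispatch;
  // the real definition lives in gemm1x1s1.h, not shown in this diff.
  inline int get_hblock(ARMArch arch) {
  #ifdef __aarch64__
    return 8;                      // prepackA_8x12 packs 8 rows of A
  #else
    return (arch == A73) ? 4 : 6;  // prepackA_4x8 vs prepackA_6x8
  #endif
  }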
void gemm1x1s1_transform_weight(const framework::Tensor &weight,
                                const framework::Tensor &output,
                                framework::Tensor *trans_weight,
                                const int group, ARMArch arch) {
  const int chout = weight.dims()[0];
  const int chin = weight.dims()[1];
  const int hout = output.dims()[2];
  const int wout = output.dims()[3];
  const int m = chout / group;
  const int n = hout * wout;
  const int k = chin / group;
  if (n > 1) {
    int hblock = get_hblock(arch);
    int m_roundup = hblock * ((m + hblock - 1) / hblock);
    int weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
    int weight_worksize = sizeof(float) * weights_size_per_group * group;
    float *w_trans_ptr = trans_weight->mutable_data<float>({weight_worksize});
    for (int g = 0; g < group; ++g) {
      const float *weights_group = weight.data<float>() + g * m * k;
      float *weights_trans_ptr = w_trans_ptr + g * weights_size_per_group;
      prepackA(weights_trans_ptr, weights_group, k, 0, m, 0, k, false, arch);
    }
  }
}
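A quick worked example of the padding above, assuming hblock = 6 (generic armv7): for m = 32 output channels and k = 128 input channels, m_roundup = 6 * ((32 + 5) / 6) = 36, and the per-group buffer is rounded up to a multiple of 16 floats, ((36 * 128 + 15) / 16) * 16 = 4608 floats. The n > 1 guard means a 1x1 output (n = hout * wout = 1) is left untransformed, matching the same special case in GemmConv1x1s1 above.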
#ifdef __aarch64__
void loadb(float *out, const float *in, const int ldin, const int k0,
           const int kmax, const int n0, const int nmax) {
  uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
  const uint32_t *inptr =
      reinterpret_cast<const uint32_t *>(in) + k0 * ldin + n0;
  uint32_t mask_buffer[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int x_len = nmax - n0;
  int y_len = kmax - k0;
  int right_remain = x_len - 12 * (x_len / 12);
  int right_pad = 12 - right_remain;
  const size_t copy_len_remain = sizeof(float) * right_remain;
  const size_t copy_len_pad = sizeof(float) * right_pad;
  const size_t size_ldin = sizeof(float) * ldin;

  uint32_t *outptr_row = outptr;
  int stride_out = 12 * y_len;

  uint32x4_t vzero = vdupq_n_u32(0);
  uint32x4_t vmask1 =
      vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
  uint32x4_t vmask2 =
      vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));
  uint32x4_t vmask3 =
      vcltq_u32(vld1q_u32(mask_buffer + 8), vdupq_n_u32(right_remain));

#pragma omp parallel for
  for (int y = 0; y < y_len - 3; y += 4) {
    const uint32_t *ptr0 = inptr + y * ldin;
    const uint32_t *ptr1 = ptr0 + ldin;
    const uint32_t *ptr2 = ptr1 + ldin;
    const uint32_t *ptr3 = ptr2 + ldin;
    asm volatile(
        "prfm pldl1keep, [%[ptr0]]        \n"
        "prfm pldl1keep, [%[ptr0], #64]   \n"
        "prfm pldl1keep, [%[ptr1]]        \n"
        "prfm pldl1keep, [%[ptr1], #64]   \n"
        "prfm pldl1keep, [%[ptr2]]        \n"
        "prfm pldl1keep, [%[ptr2], #64]   \n"
        "prfm pldl1keep, [%[ptr3]]        \n"
        "prfm pldl1keep, [%[ptr3], #64]   \n"
        :
        : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2),
          [ptr3] "r"(ptr3)
        : "memory");

    uint32_t *outptr_row_col = outptr_row + y * 12;
    int i = 0;
    for (; i < x_len - 11; i += 12) {
      uint32x4_t vr00 = vld1q_u32(ptr0);
      uint32x4_t vr01 = vld1q_u32(ptr0 + 4);
      uint32x4_t vr02 = vld1q_u32(ptr0 + 8);

      uint32x4_t vr10 = vld1q_u32(ptr1);
      uint32x4_t vr11 = vld1q_u32(ptr1 + 4);
      uint32x4_t vr12 = vld1q_u32(ptr1 + 8);

      vst1q_u32(outptr_row_col, vr00);
      vst1q_u32(outptr_row_col + 4, vr01);
      vst1q_u32(outptr_row_col + 8, vr02);

      uint32x4_t vr20 = vld1q_u32(ptr2);
      uint32x4_t vr21 = vld1q_u32(ptr2 + 4);
      uint32x4_t vr22 = vld1q_u32(ptr2 + 8);

      vst1q_u32(outptr_row_col + 12, vr10);
      vst1q_u32(outptr_row_col + 16, vr11);
      vst1q_u32(outptr_row_col + 20, vr12);

      uint32x4_t vr30 = vld1q_u32(ptr3);
      uint32x4_t vr31 = vld1q_u32(ptr3 + 4);
      uint32x4_t vr32 = vld1q_u32(ptr3 + 8);

      vst1q_u32(outptr_row_col + 24, vr20);
      vst1q_u32(outptr_row_col + 28, vr21);
      vst1q_u32(outptr_row_col + 32, vr22);

      vst1q_u32(outptr_row_col + 36, vr30);
      vst1q_u32(outptr_row_col + 40, vr31);
      vst1q_u32(outptr_row_col + 44, vr32);

      ptr0 += 12;
      ptr1 += 12;
      ptr2 += 12;
      ptr3 += 12;

      outptr_row_col += stride_out;
    }
    if (right_remain > 0) {
      uint32x4_t vr00 = vld1q_u32(ptr0);
      uint32x4_t vr01 = vld1q_u32(ptr0 + 4);
      uint32x4_t vr02 = vld1q_u32(ptr0 + 8);

      uint32x4_t vr10 = vld1q_u32(ptr1);
      uint32x4_t vr11 = vld1q_u32(ptr1 + 4);
      uint32x4_t vr12 = vld1q_u32(ptr1 + 8);

      uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero);
      uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero);
      uint32x4_t vr02_1 = vbslq_u32(vmask3, vr02, vzero);

      uint32x4_t vr20 = vld1q_u32(ptr2);
      uint32x4_t vr21 = vld1q_u32(ptr2 + 4);
      uint32x4_t vr22 = vld1q_u32(ptr2 + 8);

      vst1q_u32(outptr_row_col, vr00_1);
      vst1q_u32(outptr_row_col + 4, vr01_1);
      vst1q_u32(outptr_row_col + 8, vr02_1);

      uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero);
      uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero);
      uint32x4_t vr12_1 = vbslq_u32(vmask3, vr12, vzero);

      uint32x4_t vr30 = vld1q_u32(ptr3);
      uint32x4_t vr31 = vld1q_u32(ptr3 + 4);
      uint32x4_t vr32 = vld1q_u32(ptr3 + 8);

      vst1q_u32(outptr_row_col + 12, vr10_1);
      vst1q_u32(outptr_row_col + 16, vr11_1);
      vst1q_u32(outptr_row_col + 20, vr12_1);

      uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero);
      uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero);
      uint32x4_t vr22_1 = vbslq_u32(vmask3, vr22, vzero);

      uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero);
      uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero);
      uint32x4_t vr32_1 = vbslq_u32(vmask3, vr32, vzero);

      vst1q_u32(outptr_row_col + 24, vr20_1);
      vst1q_u32(outptr_row_col + 28, vr21_1);
      vst1q_u32(outptr_row_col + 32, vr22_1);

      vst1q_u32(outptr_row_col + 36, vr30_1);
      vst1q_u32(outptr_row_col + 40, vr31_1);
      vst1q_u32(outptr_row_col + 44, vr32_1);
    }
  }

#pragma omp parallel for
  for (int y = 4 * (y_len / 4); y < y_len; ++y) {
    const uint32_t *ptr0 = inptr + y * ldin;
    uint32_t *outptr_row_col = outptr_row + y * 12;
    int i = 0;
    for (; i < x_len - 11; i += 12) {
      uint32x4_t vr0 = vld1q_u32(ptr0);
      uint32x4_t vr1 = vld1q_u32(ptr0 + 4);
      uint32x4_t vr2 = vld1q_u32(ptr0 + 8);
      vst1q_u32(outptr_row_col, vr0);
      vst1q_u32(outptr_row_col + 4, vr1);
      vst1q_u32(outptr_row_col + 8, vr2);

      ptr0 += 12;

      outptr_row_col += stride_out;
    }
    if (right_remain > 0) {
      uint32x4_t vr0 = vld1q_u32(ptr0);
      uint32x4_t vr1 = vld1q_u32(ptr0 + 4);
      uint32x4_t vr2 = vld1q_u32(ptr0 + 8);

      uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero);
      uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero);
      uint32x4_t vr2_1 = vbslq_u32(vmask3, vr2, vzero);

      vst1q_u32(outptr_row_col, vr0_1);
      vst1q_u32(outptr_row_col + 4, vr1_1);
      vst1q_u32(outptr_row_col + 8, vr2_1);
    }
  }
}
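For readers decoding the panel layout: loadb re-tiles the k-by-n matrix B into column panels of 12 values per k-row (the kernel's NBLOCK width), zero-padding the last panel via the vbslq_u32 masks. A scalar equivalent of what the intrinsics above compute (illustrative only, assuming the aarch64 panel width of 12):

  // Scalar reference for the aarch64 loadb layout: panel p holds columns
  // [12p, 12p+12) of B for every k, stored k-major inside the panel.
  void loadb_reference(float *out, const float *in, int ldin, int k0,
                       int kmax, int n0, int nmax) {
    int x_len = nmax - n0;
    int y_len = kmax - k0;
    int panels = (x_len + 11) / 12;
    for (int p = 0; p < panels; ++p) {
      for (int y = 0; y < y_len; ++y) {
        for (int j = 0; j < 12; ++j) {
          int col = p * 12 + j;
          out[(p * y_len + y) * 12 + j] =
              (col < x_len) ? in[(k0 + y) * ldin + n0 + col] : 0.f;
        }
      }
    }
  }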
#else  //__aarch64__
void loadb(float *out, const float *in, const int ldin, const int k0,
           const int kmax, const int n0, const int nmax) {
  uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
  const uint32_t *inptr =
      reinterpret_cast<const uint32_t *>(in) + k0 * ldin + n0;
  uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int x_len = nmax - n0;
  int y_len = kmax - k0;
  int right_remain = x_len - 8 * (x_len / 8);
  int right_pad = 8 - right_remain;
  const size_t copy_len_remain = sizeof(float) * right_remain;
  const size_t copy_len_pad = sizeof(float) * right_pad;
  const size_t size_ldin = sizeof(float) * ldin;

  uint32_t *outptr_row = outptr;
  int stride_out = 8 * y_len;

  uint32x4_t vzero = vdupq_n_u32(0);
  uint32x4_t vmask1 =
      vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain));
  uint32x4_t vmask2 =
      vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain));

#pragma omp parallel for
  for (int y = 0; y < y_len - 3; y += 4) {
    const uint32_t *ptr0 = inptr + y * ldin;
    const uint32_t *ptr1 = ptr0 + ldin;
    const uint32_t *ptr2 = ptr1 + ldin;
    const uint32_t *ptr3 = ptr2 + ldin;

    uint32_t *outptr_row_col = outptr_row + y * 8;
    int i = 0;
    for (; i < x_len - 7; i += 8) {
      uint32_t *ptr_out = outptr_row_col;
      asm volatile(
          "vld1.32 {d0-d3}, [%[ptr0]]!   @ load r0, 8 elements\n"
          "vld1.32 {d4-d7}, [%[ptr1]]!   @ load r1, 8 elements\n"
          "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
          "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n"
          "vld1.32 {d0-d3}, [%[ptr2]]!   @ load r2, 8 elements\n"
          "vld1.32 {d4-d7}, [%[ptr3]]!   @ load r3, 8 elements\n"
          "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
          "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n"
          : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
            [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
          :
          : "q0", "q1", "q2", "q3", "cc", "memory");
      outptr_row_col += stride_out;
    }
    if (right_remain > 0) {
      uint32_t *ptr_out = outptr_row_col;
      asm volatile(
          "vld1.32 {d0-d3}, [%[ptr0]]!    @ load r0, 8 elements\n"
          "vld1.32 {d4-d7}, [%[ptr1]]!    @ load r1, 8 elements\n"
          "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
          "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
          //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
          "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
          "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
          "vst1.32 {d0-d3}, [%[outptr]]!  @ write to output ptr\n"
          "vst1.32 {d4-d7}, [%[outptr]]!  @ write to output ptr\n"
          "vld1.32 {d0-d3}, [%[ptr2]]!    @ load r2, 8 elements\n"
          "vld1.32 {d4-d7}, [%[ptr3]]!    @ load r3, 8 elements\n"
          "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
          "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
          //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
          "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
          "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
          "vst1.32 {d0-d3}, [%[outptr]]!  @ write to output ptr\n"
          "vst1.32 {d4-d7}, [%[outptr]]!  @ write to output ptr\n"
          : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
            [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
          : "q0", "q1", "q2", "q3", "cc", "memory");
    }
  }

#pragma omp parallel for
  for (int y = 4 * (y_len / 4); y < y_len; ++y) {
    const uint32_t *ptr0 = inptr + y * ldin;
    uint32_t *outptr_row_col = outptr_row + y * 8;
    int i = 0;
    for (; i < x_len - 7; i += 8) {
      uint32_t *ptr_out = outptr_row_col;
      asm volatile(
          "vld1.32 {d0-d3}, [%[ptr0]]!   @ load r0, 8 elements\n"
          "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n"
          : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
          :
          : "q0", "q1", "cc", "memory");
      outptr_row_col += stride_out;
    }
    if (right_remain > 0) {
      uint32_t *ptr_out = outptr_row_col;
      asm volatile(
          "vld1.32 {d0-d3}, [%[ptr0]]!    @ load r0, 8 elements\n"
          "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n"
          "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n"
          "vst1.32 {d0-d3}, [%[outptr]]!  @ write to output ptr\n"
          : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero)
          : "q0", "q1", "cc", "memory");
    }
  }
}
#endif  //__aarch64__
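The kernel below carves N into x_block-wide panels so that an MBLOCK-by-x_block result tile, an MBLOCK-by-K slice of A and an x_block-by-K slice of B together fit in L2; its comment "MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2" states the budget being solved for x. A worked instance of that arithmetic, assuming MBLOCK = 8, NBLOCK = 12 and KBLOCK = 4 (values implied by the 8x12 tile; the actual constants live in the truncated gemm1x1s1.h):

  // Hypothetical walk-through of the x_block computation in sgemm_conv_8x12:
  //   l2_cache = 512 * 1024;                         // fallback, in floats
  //   K = 256, N = 1000
  //   x_block = (524288 - 8 * 256) / (4 * (256 + 8)); // = 494
  //   x_block = 494 / 12 * 12;                        // = 492
  //   x_num   = (1000 + 491) / 492;                   // = 3 panels
  //   x_block = (1000 + 3 - 1) / 3;                   // = 334
  //   x_block = ((334 + 11) / 12) * 12;               // = 336
  // so B is processed in panels of 336, 336 and 328 columns.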
#ifdef __aarch64__
void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias,
                     float *C, int M, int N, int K, bool is_bias, bool is_relu,
                     bool transB) {
  const int threads = framework::CPUContext::Context()->get_thread_num();
  int l2_size =
      framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float);
  int l2_cache = l2_size > 0 ? l2_size : 512 * 1024;
  //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
  int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK));
  x_block /= NBLOCK;
  x_block *= NBLOCK;
  int x_num = (N + (x_block - 1)) / x_block;
  x_block = (N + x_num - 1) / x_num;
  x_block = (x_block + NBLOCK - 1) / NBLOCK;
  x_block *= NBLOCK;
  x_block = x_block < NBLOCK ? NBLOCK : x_block;

  // unroll 2 loop
  int tail_pre = (K & (KBLOCK - 1));
  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;

  bool flag_p_remain = false;
  int remain = 0;
  //! apanel is pre_compute outside gemm
  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
    unsigned int xmax = x0 + x_block;
    if (xmax > N) {
      xmax = N;
    }
    int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
    remain = xmax - x0 - (bblocks - 1) * NBLOCK;
    if (remain > 0) {
      flag_p_remain = true;
    }
    //! load bpanel
    float *b_pannel = static_cast<float *>(
        framework::CPUContext::Context()->get_work_space(
            K * (xmax - x0) * sizeof(float)));
    if (!transB) {
      loadb(b_pannel, B, N, 0, K, x0, xmax);
    }
#pragma omp parallel for num_threads(threads)
    for (unsigned int y = 0; y < M; y += MBLOCK) {
      unsigned int ymax = y + MBLOCK;
      if (ymax > M) {
        ymax = M;
      }
      float bias_local[8] = {0};
      if (is_bias) {
        bias_local[0] = bias[y];
        bias_local[1] = bias[y + 1];
        bias_local[2] = bias[y + 2];
        bias_local[3] = bias[y + 3];
        bias_local[4] = bias[y + 4];
        bias_local[5] = bias[y + 5];
        bias_local[6] = bias[y + 6];
        bias_local[7] = bias[y + 7];
      }
      float cout0[NBLOCK];
      float cout1[NBLOCK];
      float cout2[NBLOCK];
      float cout3[NBLOCK];
      float cout4[NBLOCK];
      float cout5[NBLOCK];
      float cout6[NBLOCK];
      float cout7[NBLOCK];

      float *c_ptr0 = C + y * N + x0;
      float *c_ptr1 = c_ptr0 + N;
      float *c_ptr2 = c_ptr1 + N;
      float *c_ptr3 = c_ptr2 + N;
      float *c_ptr4 = c_ptr3 + N;
      float *c_ptr5 = c_ptr4 + N;
      float *c_ptr6 = c_ptr5 + N;
      float *c_ptr7 = c_ptr6 + N;

      float *pout0 = c_ptr0;
      float *pout1 = c_ptr1;
      float *pout2 = c_ptr2;
      float *pout3 = c_ptr3;
      float *pout4 = c_ptr4;
      float *pout5 = c_ptr5;
      float *pout6 = c_ptr6;
      float *pout7 = c_ptr7;

      const float *a_ptr_l = A_packed + y * K;
      const float *b_ptr = b_pannel;
      for (int xb = 0; xb < bblocks; xb++) {
        if ((y + 7) >= ymax) {
          switch ((y + 7) - ymax) {
            case 6:
              c_ptr1 = cout1;
            case 5:
              c_ptr2 = cout2;
            case 4:
              c_ptr3 = cout3;
            case 3:
              c_ptr4 = cout4;
            case 2:
              c_ptr5 = cout5;
            case 1:
              c_ptr6 = cout6;
            case 0:
              c_ptr7 = cout7;
            default:
              break;
          }
        }
        if (flag_p_remain && (xb == bblocks - 1)) {
          pout0 = c_ptr0;
          pout1 = c_ptr1;
          pout2 = c_ptr2;
          pout3 = c_ptr3;
          pout4 = c_ptr4;
          pout5 = c_ptr5;
          pout6 = c_ptr6;
          pout7 = c_ptr7;

          c_ptr0 = cout0;
          c_ptr1 = cout1;
          c_ptr2 = cout2;
          c_ptr3 = cout3;
          c_ptr4 = cout4;
          c_ptr5 = cout5;
          c_ptr6 = cout6;
          c_ptr7 = cout7;
        }
        const float *a_ptr = a_ptr_l;
        int tail = tail_pre;
        int k = k_pre;
        asm volatile(
            // Initialize result registers, load initial operands, prime
            // prefetches.
            "ldp q2, q3, [%[bias_ptr]]\n"        /* load bias to q2, q3*/
            "ldp q0, q1, [%[a_ptr]], #32\n"      /* load a00,a01 to q0, q1*/
            "ldp q4, q5, [%[b_ptr]], #32\n"      /* load b0, b1 to q4, q5*/
            "dup v8.4s, v2.s[0]\n"               /* out0 = 0 */
            "dup v9.4s, v2.s[0]\n"               /* out1 = 0*/
            "dup v10.4s, v2.s[0]\n"              /* out2 = 0*/
            "dup v11.4s, v2.s[1]\n"              /* out3 = 0*/
            "dup v12.4s, v2.s[1]\n"              /* out4 = 0*/
            "prfm pldl1keep, [%[b_ptr], #64]\n"  /* preload b*/
            "dup v13.4s, v2.s[1]\n"              /* out5 = 0*/
            "prfm pldl1keep, [%[a_ptr], #64]\n"  /* preload a*/
            "dup v14.4s, v2.s[2]\n"              /* out6 = 0*/
            "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/
            "dup v15.4s, v2.s[2]\n"              /* out7 = 0*/
            "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/
            "dup v16.4s, v2.s[2]\n"              /* out8 = 0*/
            "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/
            "dup v17.4s, v2.s[3]\n"              /* out9 = 0*/
            "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/
            "dup v18.4s, v2.s[3]\n"              /* out10 = 0*/
            "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/
            "dup v19.4s, v2.s[3]\n"              /* out11 = 0*/
            "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/
            "dup v20.4s, v3.s[0]\n"              /* out12 = 0*/
            "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/
            "dup v21.4s, v3.s[0]\n"              /* out13 = 0*/
            "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/
            "dup v22.4s, v3.s[0]\n"              /* out14 = 0*/
            "dup v23.4s, v3.s[1]\n"              /* out15 = 0*/
            "dup v24.4s, v3.s[1]\n"              /* out16 = 0*/
            "dup v25.4s, v3.s[1]\n"              /* out17 = 0*/
            "dup v26.4s, v3.s[2]\n"              /* out18 = 0*/
            "dup v27.4s, v3.s[2]\n"              /* out19 = 0*/
            "dup v28.4s, v3.s[2]\n"              /* out20 = 0*/
            "dup v29.4s, v3.s[3]\n"              /* out21 = 0*/
            "dup v30.4s, v3.s[3]\n"              /* out22 = 0*/
            "dup v31.4s, v3.s[3]\n"              /* out23 = 0*/
            "cbz %w[k], 2f\n"                    /* check loop count > 0 */
            /* main loop */
            /* unrool 0*/
            "1:\n"                            /* main loop */
            "fmla v8.4s , v4.4s, v0.s[0]\n"   /* out0 = b0 * a00[0], b0 = q4 */
            "fmla v11.4s , v4.4s, v0.s[1]\n"  /* out1 = b0 * a00[1], b0 = q4 */
            "ldp q6, q7, [%[b_ptr]], #32\n"   /* load b2, b0 to q6, q7 */
            "fmla v14.4s, v4.4s, v0.s[2]\n"   /* out2 = b0 * a00[2], b0 = q4 */
            "fmla v17.4s, v4.4s, v0.s[3]\n"   /* out3 = b0 * a00[3], b0 = q4 */
            "ldp q2, q3, [%[a_ptr]], #32\n"   /* load a10, a11 to q3, q4 */
            "fmla v20.4s, v4.4s, v1.s[0]\n"   /* out4 = b0 * a01[0], b0 = q4 */
            "fmla v23.4s, v4.4s, v1.s[1]\n"   /* out5 = b0 * a01[1], b0 = q4 */
            "fmla v26.4s, v4.4s, v1.s[2]\n"   /* out6 = b0 * a01[2], b0 = q4 */
            "fmla v29.4s, v4.4s, v1.s[3]\n"   /* out7 = b0 * a01[3], b0 = q4 */
            "fmla v9.4s, v5.4s, v0.s[0]\n"    /* out8 = b1 * a00[0], b1 = q5 */
            "fmla v12.4s, v5.4s, v0.s[1]\n"   /* out9 = b1 * a00[1], b1 = q5 */
            "fmla v15.4s, v5.4s, v0.s[2]\n"   /* out10 = b1 * a00[2], b1 = q5*/
            "fmla v18.4s, v5.4s, v0.s[3]\n"   /* out11 = b1 * a00[3], b1 = q5*/
            "fmla v21.4s, v5.4s, v1.s[0]\n"   /* out12 = b1 * a01[0], b1 = q5*/
            "fmla v24.4s, v5.4s, v1.s[1]\n"   /* out13 = b1 * a01[1], b1 = q5*/
            "fmla v27.4s, v5.4s, v1.s[2]\n"   /* out14 = b1 * a01[2], b1 = q5*/
            "fmla v30.4s, v5.4s, v1.s[3]\n"   /* out15 = b1 * a01[3], b1 = q5*/
            "ldp q4, q5, [%[b_ptr]], #32\n"   /* load b1, b2 to q4, q5 */
            "fmla v10.4s, v6.4s, v0.s[0]\n"   /* out16 = b2 * a00[0], b2 = q6*/
            "fmla v13.4s, v6.4s, v0.s[1]\n"   /* out17 = b2 * a00[1], b2 = q6*/
            "prfm pldl1keep, [%[b_ptr], #384]\n"
            "fmla v16.4s, v6.4s, v0.s[2]\n"   /* out18 = b2 * a00[2], b2 = q6*/
            "fmla v19.4s, v6.4s, v0.s[3]\n"   /* out19 = b2 * a00[3], b2 = q6*/
            "fmla v22.4s, v6.4s, v1.s[0]\n"   /* out20 = b2 * a00[0], b2 = q6*/
            "fmla v25.4s, v6.4s, v1.s[1]\n"   /* out21 = b2 * a00[1], b2 = q6*/
            "fmla v28.4s, v6.4s, v1.s[2]\n"   /* out22 = b2 * a00[2], b2 = q6*/
            "fmla v31.4s, v6.4s, v1.s[3]\n"   /* out23 = b2 * a00[3], b2 = q6*/
            "ldp q0, q1, [%[a_ptr]], #32\n"   /* load a00, a01 to q0, q1 */
            /* unrool 1 */
            "fmla v8.4s , v7.4s, v2.s[0]\n"   /* out0 = b0 * a10[0], b0 = q7 */
            "fmla v11.4s , v7.4s, v2.s[1]\n"  /* out1 = b0 * a10[1], b0 = q7 */
            "fmla v14.4s, v7.4s, v2.s[2]\n"   /* out2 = b0 * a10[2], b0 = q7 */
            "prfm pldl1keep, [%[a_ptr], #256]\n"
            "fmla v17.4s, v7.4s, v2.s[3]\n"   /* out3 = b0 * a10[3], b0 = q7 */
            "fmla v20.4s, v7.4s, v3.s[0]\n"   /* out4 = b0 * a11[0], b0 = q7 */
            "fmla v23.4s, v7.4s, v3.s[1]\n"   /* out5 = b0 * a11[1], b0 = q7 */
            "fmla v26.4s, v7.4s, v3.s[2]\n"   /* out6 = b0 * a11[2], b0 = q7 */
            "fmla v29.4s, v7.4s, v3.s[3]\n"   /* out7 = b0 * a11[3], b0 = q7 */
            "ldp q6, q7, [%[b_ptr]], #32\n"   /* load b0, b1 to q6, q7 */
            "fmla v9.4s, v4.4s, v2.s[0]\n"    /* out8 = b0 * a10[0], b1 = q4 */
            "fmla v12.4s, v4.4s, v2.s[1]\n"   /* out9 = b0 * a10[1], b1 = q4 */
            "fmla v15.4s, v4.4s, v2.s[2]\n"   /* out10 = b1 * a10[2], b1 = q4*/
            "fmla v18.4s, v4.4s, v2.s[3]\n"   /* out11 = b1 * a10[3], b1 = q4*/
            "fmla v21.4s, v4.4s, v3.s[0]\n"   /* out12 = b1 * a10[0], b1 = q4*/
            "fmla v24.4s, v4.4s, v3.s[1]\n"   /* out13 = b1 * a10[1], b1 = q4*/
            "fmla v27.4s, v4.4s, v3.s[2]\n"   /* out14 = b1 * a10[2], b1 = q4*/
            "fmla v30.4s, v4.4s, v3.s[3]\n"   /* out15 = b1 * a10[3], b1 = q4*/
            "fmla v10.4s, v5.4s, v2.s[0]\n"   /* out16 = b2 * a10[0], b2 = q5*/
            "fmla v13.4s, v5.4s, v2.s[1]\n"   /* out17 = b2 * a10[0], b2 = q5*/
            "fmla v16.4s, v5.4s, v2.s[2]\n"   /* out18 = b2 * a10[0], b2 = q5*/
            "fmla v19.4s, v5.4s, v2.s[3]\n"   /* out19 = b2 * a10[0], b2 = q5*/
            "fmla v22.4s, v5.4s, v3.s[0]\n"   /* out20 = b2 * a10[0], b2 = q5*/
            "fmla v25.4s, v5.4s, v3.s[1]\n"   /* out21 = b2 * a10[0], b2 = q5*/
            "fmla v28.4s, v5.4s, v3.s[2]\n"   /* out22 = b2 * a10[0], b2 = q5*/
            "fmla v31.4s, v5.4s, v3.s[3]\n"   /* out23 = b2 * a10[0], b2 = q5*/
            "ldp q4, q5, [%[b_ptr]], #32\n"   /* load b2, b0 to q4, q5 */
            /* unrool 2*/
            "fmla v8.4s , v6.4s, v0.s[0]\n"   /* out0 = b0 * a00[0], b0 = q6 */
            "fmla v11.4s , v6.4s, v0.s[1]\n"  /* out1 = b0 * a00[1], b0 = q6 */
            "ldp q2, q3, [%[a_ptr]], #32\n"   /* load a10, a11 to q3, q4*/
            "fmla v14.4s, v6.4s, v0.s[2]\n"   /* out2 = b0 * a00[2], b0 = q6*/
            "fmla v17.4s, v6.4s, v0.s[3]\n"   /* out3 = b0 * a00[3], b0 = q6*/
            "fmla v20.4s, v6.4s, v1.s[0]\n"   /* out4 = b0 * a01[0], b0 = q6*/
            "fmla v23.4s, v6.4s, v1.s[1]\n"   /* out5 = b0 * a01[1], b0 = q6*/
            "fmla v26.4s, v6.4s, v1.s[2]\n"   /* out6 = b0 * a01[2], b0 = q6*/
            "fmla v29.4s, v6.4s, v1.s[3]\n"   /* out7 = b0 * a01[3], b0 = q6*/
            "fmla v9.4s, v7.4s, v0.s[0]\n"    /* out8 = b1 * a00[0], b1 = q7*/
            "fmla v12.4s, v7.4s, v0.s[1]\n"   /* out9 = b1 * a00[1], b1 = q7*/
            "prfm pldl1keep, [%[b_ptr], #384]\n"
            "fmla v15.4s, v7.4s, v0.s[2]\n"   /* out10 = b1 * a00[2], b1 = q7*/
            "fmla v18.4s, v7.4s, v0.s[3]\n"   /* out11 = b1 * a00[3], b1 = q7*/
            "fmla v21.4s, v7.4s, v1.s[0]\n"   /* out12 = b1 * a01[0], b1 = q7*/
            "fmla v24.4s, v7.4s, v1.s[1]\n"   /* out13 = b1 * a01[1], b1 = q7*/
            "fmla v27.4s, v7.4s, v1.s[2]\n"   /* out14 = b1 * a01[2], b1 = q7*/
            "fmla v30.4s, v7.4s, v1.s[3]\n"   /* out15 = b1 * a01[3], b1 = q7*/
            "ldp q6, q7, [%[b_ptr]], #32\n"   /* load b1, b2 to q6, q7*/
            "fmla v10.4s, v4.4s, v0.s[0]\n"   /* out16 = b2 * a00[0], b2 = q4*/
            "fmla v13.4s, v4.4s, v0.s[1]\n"   /* out17 = b2 * a00[1], b2 = q4*/
            "fmla v16.4s, v4.4s, v0.s[2]\n"   /* out18 = b2 * a00[2], b2 = q4*/
            "fmla v19.4s, v4.4s, v0.s[3]\n"   /* out19 = b2 * a00[3], b2 = q4*/
            "fmla v22.4s, v4.4s, v1.s[0]\n"   /* out20 = b2 * a00[0], b2 = q4*/
            "fmla v25.4s, v4.4s, v1.s[1]\n"   /* out21 = b2 * a00[1], b2 = q4*/
            "fmla v28.4s, v4.4s, v1.s[2]\n"   /* out22 = b2 * a00[2], b2 = q4*/
            "fmla v31.4s, v4.4s, v1.s[3]\n"   /* out23 = b2 * a00[3], b2 = q4*/
            "ldp q0, q1, [%[a_ptr]], #32\n"   /* load a00, a01 to q0, q1*/
            /* unrool 3*/
            "fmla v8.4s , v5.4s, v2.s[0]\n"   /* out0 = b0 * a10[0], b0 = q5*/
            "fmla v11.4s , v5.4s, v2.s[1]\n"  /* out1 = b0 * a10[1], b0 = q5*/
            "fmla v14.4s, v5.4s, v2.s[2]\n"   /* out2 = b0 * a10[2], b0 = q5*/
            "fmla v17.4s, v5.4s, v2.s[3]\n"   /* out3 = b0 * a10[3], b0 = q5*/
            "fmla v20.4s, v5.4s, v3.s[0]\n"   /* out4 = b0 * a11[0], b0 = q5*/
            "fmla v23.4s, v5.4s, v3.s[1]\n"   /* out5 = b0 * a11[1], b0 = q5*/
            "fmla v26.4s, v5.4s, v3.s[2]\n"   /* out6 = b0 * a11[2], b0 = q5*/
            "fmla v29.4s, v5.4s, v3.s[3]\n"   /* out7 = b0 * a11[3], b0 = q5*/
            "ldp q4, q5, [%[b_ptr]], #32\n"   /* load b0, b1 to q4, q5*/
            "fmla v9.4s, v6.4s, v2.s[0]\n"    /* out8 = b0 * a10[0], b1 = q6*/
            "fmla v12.4s, v6.4s, v2.s[1]\n"   /* out9 = b0 * a10[1], b1 = q6*/
            "prfm pldl1keep, [%[a_ptr], #256]
\n
"
"fmla v15.4s, v6.4s, v2.s[2]
\n
"
/* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v6.4s, v2.s[3]
\n
"
/* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v6.4s, v3.s[0]
\n
"
/* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v6.4s, v3.s[1]
\n
"
/* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v6.4s, v3.s[2]
\n
"
/* out14 = b1 * a10[2], b1 =
q6*/
"prfm pldl1keep, [%[b_ptr], #384]
\n
"
"fmla v30.4s, v6.4s, v3.s[3]
\n
"
/* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v7.4s, v2.s[0]
\n
"
/* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v7.4s, v2.s[1]
\n
"
/* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v7.4s, v2.s[2]
\n
"
/* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v7.4s, v2.s[3]
\n
"
/* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v7.4s, v3.s[0]
\n
"
/* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v7.4s, v3.s[1]
\n
"
/* out21 = b2 * a10[0], b2 =
q7*/
"subs %w[k], %w[k], #1
\n
"
/* loop count - 1*/
"fmla v28.4s, v7.4s, v3.s[2]
\n
"
/* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v7.4s, v3.s[3]
\n
"
/* out23 = b2 * a10[0], b2 =
q7*/
"bne 1b
\n
"
/* Target to use when K is 1 or 2 (i.e. zero iterations of main
loop)*/
"2:
\n
"
/* process tail*/
"subs %w[tail], %w[tail], #1
\n
"
/* tail--*/
"beq 3f
\n
"
/*jump to tail = 1*/
/* final unrool 0*/
/* unrool 0, tail > 1*/
"fmla v8.4s , v4.4s, v0.s[0]
\n
"
/* out0 = b0 * a00[0], b0 = q4*/
"fmla v11.4s , v4.4s, v0.s[1]
\n
"
/* out1 = b0 * a00[1], b0 =
q4*/
"ldp q6, q7, [%[b_ptr]], #32
\n
"
/* load b2, b0 to q6, q7*/
"fmla v14.4s, v4.4s, v0.s[2]
\n
"
/* out2 = b0 * a00[2], b0 = q4*/
"fmla v17.4s, v4.4s, v0.s[3]
\n
"
/* out3 = b0 * a00[3], b0 = q4*/
"ldp q2, q3, [%[a_ptr]], #32
\n
"
/* load a10, a11 to q2, q3*/
"fmla v20.4s, v4.4s, v1.s[0]
\n
"
/* out4 = b0 * a01[0], b0 = q4*/
"fmla v23.4s, v4.4s, v1.s[1]
\n
"
/* out5 = b0 * a01[1], b0 = q4*/
"fmla v26.4s, v4.4s, v1.s[2]
\n
"
/* out6 = b0 * a01[2], b0 = q4*/
"fmla v29.4s, v4.4s, v1.s[3]
\n
"
/* out7 = b0 * a01[3], b0 = q4*/
"subs %w[tail], %w[tail], #1
\n
"
/* tail--*/
"fmla v9.4s, v5.4s, v0.s[0]
\n
"
/* out8 = b1 * a00[0], b1 = q5*/
"fmla v12.4s, v5.4s, v0.s[1]
\n
"
/* out9 = b1 * a00[1], b1 = q5*/
"fmla v15.4s, v5.4s, v0.s[2]
\n
"
/* out10 = b1 * a00[2], b1 =
q5*/
"fmla v18.4s, v5.4s, v0.s[3]
\n
"
/* out11 = b1 * a00[3], b1 =
q5*/
"fmla v21.4s, v5.4s, v1.s[0]
\n
"
/* out12 = b1 * a01[0], b1 =
q5*/
"fmla v24.4s, v5.4s, v1.s[1]
\n
"
/* out13 = b1 * a01[1], b1 =
q5*/
"fmla v27.4s, v5.4s, v1.s[2]
\n
"
/* out14 = b1 * a01[2], b1 =
q5*/
"fmla v30.4s, v5.4s, v1.s[3]
\n
"
/* out15 = b1 * a01[3], b1 =
q5*/
"ldp q4, q5, [%[b_ptr]], #32
\n
"
/* load b1, b2 to q4, q5*/
"fmla v10.4s, v6.4s, v0.s[0]
\n
"
/* out16 = b2 * a00[0], b2 =
q6*/
"fmla v13.4s, v6.4s, v0.s[1]
\n
"
/* out17 = b2 * a00[1], b2 =
q6*/
"fmla v16.4s, v6.4s, v0.s[2]
\n
"
/* out18 = b2 * a00[2], b2 =
q6*/
"fmla v19.4s, v6.4s, v0.s[3]
\n
"
/* out19 = b2 * a00[3], b2 =
q6*/
"fmla v22.4s, v6.4s, v1.s[0]
\n
"
/* out20 = b2 * a00[0], b2 =
q6*/
"fmla v25.4s, v6.4s, v1.s[1]
\n
"
/* out21 = b2 * a00[1], b2 =
q6*/
"fmla v28.4s, v6.4s, v1.s[2]
\n
"
/* out22 = b2 * a00[2], b2 =
q6*/
"fmla v31.4s, v6.4s, v1.s[3]
\n
"
/* out23 = b2 * a00[3], b2 =
q6*/
"beq 4f
\n
"
/*jump to tail = 2*/
/* unrool 1, tail > 2*/
"ldp q0, q1, [%[a_ptr]], #32
\n
"
/* load a00, a01 to q0, q1*/
"fmla v8.4s , v7.4s, v2.s[0]
\n
"
/* out0 = b0 * a10[0], b0 = q7*/
"fmla v11.4s , v7.4s, v2.s[1]
\n
"
/* out1 = b0 * a10[1], b0 =
q7*/
"fmla v14.4s, v7.4s, v2.s[2]
\n
"
/* out2 = b0 * a10[2], b0 = q7*/
"fmla v17.4s, v7.4s, v2.s[3]
\n
"
/* out3 = b0 * a10[3], b0 = q7*/
"fmla v20.4s, v7.4s, v3.s[0]
\n
"
/* out4 = b0 * a11[0], b0 = q7*/
"fmla v23.4s, v7.4s, v3.s[1]
\n
"
/* out5 = b0 * a11[1], b0 = q7*/
"fmla v26.4s, v7.4s, v3.s[2]
\n
"
/* out6 = b0 * a11[2], b0 = q7*/
"fmla v29.4s, v7.4s, v3.s[3]
\n
"
/* out7 = b0 * a11[3], b0 = q7*/
"ldp q6, q7, [%[b_ptr]], #32
\n
"
/* load b0, b1 to q6, q7*/
"fmla v9.4s, v4.4s, v2.s[0]
\n
"
/* out8 = b0 * a10[0], b1 = q4*/
"fmla v12.4s, v4.4s, v2.s[1]
\n
"
/* out9 = b0 * a10[1], b1 = q4*/
"fmla v15.4s, v4.4s, v2.s[2]
\n
"
/* out10 = b1 * a10[2], b1 =
q4*/
"fmla v18.4s, v4.4s, v2.s[3]
\n
"
/* out11 = b1 * a10[3], b1 =
q4*/
"fmla v21.4s, v4.4s, v3.s[0]
\n
"
/* out12 = b1 * a10[0], b1 =
q4*/
"fmla v24.4s, v4.4s, v3.s[1]
\n
"
/* out13 = b1 * a10[1], b1 =
q4*/
"fmla v27.4s, v4.4s, v3.s[2]
\n
"
/* out14 = b1 * a10[2], b1 =
q4*/
"fmla v30.4s, v4.4s, v3.s[3]
\n
"
/* out15 = b1 * a10[3], b1 =
q4*/
"subs %w[tail], %w[tail], #1
\n
"
/* tail--*/
"fmla v10.4s, v5.4s, v2.s[0]
\n
"
/* out16 = b2 * a10[0], b2 =
q5*/
"fmla v13.4s, v5.4s, v2.s[1]
\n
"
/* out17 = b2 * a10[0], b2 =
q5*/
"fmla v16.4s, v5.4s, v2.s[2]
\n
"
/* out18 = b2 * a10[0], b2 =
q5*/
"fmla v19.4s, v5.4s, v2.s[3]
\n
"
/* out19 = b2 * a10[0], b2 =
q5*/
"fmla v22.4s, v5.4s, v3.s[0]
\n
"
/* out20 = b2 * a10[0], b2 =
q5*/
"fmla v25.4s, v5.4s, v3.s[1]
\n
"
/* out21 = b2 * a10[0], b2 =
q5*/
"fmla v28.4s, v5.4s, v3.s[2]
\n
"
/* out22 = b2 * a10[0], b2 =
q5*/
"fmla v31.4s, v5.4s, v3.s[3]
\n
"
/* out23 = b2 * a10[0], b2 =
q5*/
"beq 5f
\n
"
/*jump to tail = 3*/
/* unrool 2, tail = 4*/
"ldp q4, q5, [%[b_ptr]], #32
\n
"
/* load b2, b0 to q4, q5*/
"fmla v8.4s , v6.4s, v0.s[0]
\n
"
/* out0 = b0 * a00[0], b0 = q6*/
"fmla v11.4s , v6.4s, v0.s[1]
\n
"
/* out1 = b0 * a00[1], b0 =
q6*/
"ldp q2, q3, [%[a_ptr]], #32
\n
"
/* load a10, a11 to q3, q4*/
"fmla v14.4s, v6.4s, v0.s[2]
\n
"
/* out2 = b0 * a00[2], b0 = q6*/
"fmla v17.4s, v6.4s, v0.s[3]
\n
"
/* out3 = b0 * a00[3], b0 = q6*/
"fmla v20.4s, v6.4s, v1.s[0]
\n
"
/* out4 = b0 * a01[0], b0 = q6*/
"fmla v23.4s, v6.4s, v1.s[1]
\n
"
/* out5 = b0 * a01[1], b0 = q6*/
"fmla v26.4s, v6.4s, v1.s[2]
\n
"
/* out6 = b0 * a01[2], b0 = q6*/
"fmla v29.4s, v6.4s, v1.s[3]
\n
"
/* out7 = b0 * a01[3], b0 = q6*/
"fmla v9.4s, v7.4s, v0.s[0]
\n
"
/* out8 = b1 * a00[0], b1 = q7*/
"fmla v12.4s, v7.4s, v0.s[1]
\n
"
/* out9 = b1 * a00[1], b1 = q7*/
"fmla v15.4s, v7.4s, v0.s[2]
\n
"
/* out10 = b1 * a00[2], b1 =
q7*/
"fmla v18.4s, v7.4s, v0.s[3]
\n
"
/* out11 = b1 * a00[3], b1 =
q7*/
"fmla v21.4s, v7.4s, v1.s[0]
\n
"
/* out12 = b1 * a01[0], b1 =
q7*/
"fmla v24.4s, v7.4s, v1.s[1]
\n
"
/* out13 = b1 * a01[1], b1 =
q7*/
"fmla v27.4s, v7.4s, v1.s[2]
\n
"
/* out14 = b1 * a01[2], b1 =
q7*/
"fmla v30.4s, v7.4s, v1.s[3]
\n
"
/* out15 = b1 * a01[3], b1 =
q7*/
"ldp q6, q7, [%[b_ptr]], #32
\n
"
/* load b1, b2 to q6, q7*/
"fmla v10.4s, v4.4s, v0.s[0]
\n
"
/* out16 = b2 * a00[0], b2 =
q4*/
"fmla v13.4s, v4.4s, v0.s[1]
\n
"
/* out17 = b2 * a00[1], b2 =
q4*/
"fmla v16.4s, v4.4s, v0.s[2]
\n
"
/* out18 = b2 * a00[2], b2 =
q4*/
"fmla v19.4s, v4.4s, v0.s[3]
\n
"
/* out19 = b2 * a00[3], b2 =
q4*/
"fmla v22.4s, v4.4s, v1.s[0]
\n
"
/* out20 = b2 * a00[0], b2 =
q4*/
"fmla v25.4s, v4.4s, v1.s[1]
\n
"
/* out21 = b2 * a00[1], b2 =
q4*/
"fmla v28.4s, v4.4s, v1.s[2]
\n
"
/* out22 = b2 * a00[2], b2 =
q4*/
"fmla v31.4s, v4.4s, v1.s[3]
\n
"
/* out23 = b2 * a00[3], b2 =
q4*/
/* unrool 3, tail = 4*/
"fmla v8.4s , v5.4s, v2.s[0]
\n
"
/* out0 = b0 * a10[0], b0 = q5*/
"fmla v11.4s , v5.4s, v2.s[1]
\n
"
/* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v5.4s, v2.s[2]
\n
"
/* out2 = b0 * a10[2], b0 = q5*/
"fmla v17.4s, v5.4s, v2.s[3]
\n
"
/* out3 = b0 * a10[3], b0 = q5*/
"fmla v20.4s, v5.4s, v3.s[0]
\n
"
/* out4 = b0 * a11[0], b0 = q5*/
"fmla v23.4s, v5.4s, v3.s[1]
\n
"
/* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v5.4s, v3.s[2]
\n
"
/* out6 = b0 * a11[2], b0 = q5*/
"fmla v29.4s, v5.4s, v3.s[3]
\n
"
/* out7 = b0 * a11[3], b0 = q5*/
"fmla v9.4s, v6.4s, v2.s[0]
\n
"
/* out8 = b0 * a10[0], b1 = q6*/
"fmla v12.4s, v6.4s, v2.s[1]
\n
"
/* out9 = b1 * a10[1], b1 = q6*/
"fmla v15.4s, v6.4s, v2.s[2]
\n
"
/* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v6.4s, v2.s[3]
\n
"
/* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v6.4s, v3.s[0]
\n
"
/* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v6.4s, v3.s[1]
\n
"
/* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v6.4s, v3.s[2]
\n
"
/* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v6.4s, v3.s[3]
\n
"
/* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v7.4s, v2.s[0]
\n
"
/* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v7.4s, v2.s[1]
\n
"
/* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v7.4s, v2.s[2]
\n
"
/* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v7.4s, v2.s[3]
\n
"
/* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v7.4s, v3.s[0]
\n
"
/* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v7.4s, v3.s[1]
\n
"
/* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v7.4s, v3.s[2]
\n
"
/* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v7.4s, v3.s[3]
\n
"
/* out23 = b2 * a10[0], b2 =
q7*/
"b 11f
\n
"
/* tails==1 final tail*/
"3:
\n
"
/* tail=1*/
"ldr q6, [%[b_ptr]], #16
\n
"
/* load b2 to q6*/
"fmla v8.4s , v4.4s, v0.s[0]
\n
"
/* out0 = b0 * a10[0], b0 = q5*/
"fmla v11.4s , v4.4s, v0.s[1]
\n
"
/* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v4.4s, v0.s[2]
\n
"
/* out2 = b0 * a10[2], b0 = q5*/
"fmla v17.4s, v4.4s, v0.s[3]
\n
"
/* out3 = b0 * a10[3], b0 = q5*/
"fmla v20.4s, v4.4s, v1.s[0]
\n
"
/* out4 = b0 * a11[0], b0 = q5*/
"fmla v23.4s, v4.4s, v1.s[1]
\n
"
/* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v4.4s, v1.s[2]
\n
"
/* out6 = b0 * a11[2], b0 = q5*/
"fmla v29.4s, v4.4s, v1.s[3]
\n
"
/* out7 = b0 * a11[3], b0 = q5*/
"fmla v9.4s, v5.4s, v0.s[0]
\n
"
/* out8 = b0 * a10[0], b1 = q6*/
"fmla v12.4s, v5.4s, v0.s[1]
\n
"
/* out9 = b1 * a10[1], b1 = q6*/
"fmla v15.4s, v5.4s, v0.s[2]
\n
"
/* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v5.4s, v0.s[3]
\n
"
/* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v5.4s, v1.s[0]
\n
"
/* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v5.4s, v1.s[1]
\n
"
/* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v5.4s, v1.s[2]
\n
"
/* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v5.4s, v1.s[3]
\n
"
/* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v6.4s, v0.s[0]
\n
"
/* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v6.4s, v0.s[1]
\n
"
/* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v6.4s, v0.s[2]
\n
"
/* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v6.4s, v0.s[3]
\n
"
/* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v6.4s, v1.s[0]
\n
"
/* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v6.4s, v1.s[1]
\n
"
/* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v6.4s, v1.s[2]
\n
"
/* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v6.4s, v1.s[3]
\n
"
/* out23 = b2 * a10[0], b2 =
q7*/
"b 11f
\n
"
/* tails==2 final tail*/
"4:
\n
"
/* tail = 2*/
"fmla v8.4s , v7.4s, v2.s[0]
\n
"
/* out0 = b0 * a10[0], b0 = q5*/
"fmla v11.4s , v7.4s, v2.s[1]
\n
"
/* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v7.4s, v2.s[2]
\n
"
/* out2 = b0 * a10[2], b0 = q5*/
"fmla v17.4s, v7.4s, v2.s[3]
\n
"
/* out3 = b0 * a10[3], b0 = q5*/
"fmla v20.4s, v7.4s, v3.s[0]
\n
"
/* out4 = b0 * a11[0], b0 = q5*/
"fmla v23.4s, v7.4s, v3.s[1]
\n
"
/* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v7.4s, v3.s[2]
\n
"
/* out6 = b0 * a11[2], b0 = q5*/
"fmla v29.4s, v7.4s, v3.s[3]
\n
"
/* out7 = b0 * a11[3], b0 = q5*/
"fmla v9.4s, v4.4s, v2.s[0]
\n
"
/* out8 = b0 * a10[0], b1 = q6*/
"fmla v12.4s, v4.4s, v2.s[1]
\n
"
/* out9 = b1 * a10[1], b1 = q6*/
"fmla v15.4s, v4.4s, v2.s[2]
\n
"
/* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v4.4s, v2.s[3]
\n
"
/* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v4.4s, v3.s[0]
\n
"
/* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v4.4s, v3.s[1]
\n
"
/* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v4.4s, v3.s[2]
\n
"
/* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v4.4s, v3.s[3]
\n
"
/* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v5.4s, v2.s[0]
\n
"
/* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v5.4s, v2.s[1]
\n
"
/* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v5.4s, v2.s[2]
\n
"
/* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v5.4s, v2.s[3]
\n
"
/* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v5.4s, v3.s[0]
\n
"
/* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v5.4s, v3.s[1]
\n
"
/* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v5.4s, v3.s[2]
\n
"
/* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v5.4s, v3.s[3]
\n
"
/* out23 = b2 * a10[0], b2 =
q7*/
"b 11f
\n
"
/* tails==3 final tail*/
"5:
\n
"
/* tail = 3*/
"ldr q4, [%[b_ptr]], #16
\n
"
/* load b2, b0 to q4*/
"fmla v8.4s , v6.4s, v0.s[0]
\n
"
/* out0 = b0 * a10[0], b0 = q5*/
"fmla v11.4s , v6.4s, v0.s[1]
\n
"
/* out1 = b0 * a10[1], b0 =
q5*/
"fmla v14.4s, v6.4s, v0.s[2]
\n
"
/* out2 = b0 * a10[2], b0 = q5*/
"fmla v17.4s, v6.4s, v0.s[3]
\n
"
/* out3 = b0 * a10[3], b0 = q5*/
"fmla v20.4s, v6.4s, v1.s[0]
\n
"
/* out4 = b0 * a11[0], b0 = q5*/
"fmla v23.4s, v6.4s, v1.s[1]
\n
"
/* out5 = b0 * a11[1], b0 = q5*/
"fmla v26.4s, v6.4s, v1.s[2]
\n
"
/* out6 = b0 * a11[2], b0 = q5*/
"fmla v29.4s, v6.4s, v1.s[3]
\n
"
/* out7 = b0 * a11[3], b0 = q5*/
"fmla v9.4s, v7.4s, v0.s[0]
\n
"
/* out8 = b0 * a10[0], b1 = q6*/
"fmla v12.4s, v7.4s, v0.s[1]
\n
"
/* out9 = b1 * a10[1], b1 = q6*/
"fmla v15.4s, v7.4s, v0.s[2]
\n
"
/* out10 = b1 * a10[2], b1 =
q6*/
"fmla v18.4s, v7.4s, v0.s[3]
\n
"
/* out11 = b1 * a10[3], b1 =
q6*/
"fmla v21.4s, v7.4s, v1.s[0]
\n
"
/* out12 = b1 * a10[0], b1 =
q6*/
"fmla v24.4s, v7.4s, v1.s[1]
\n
"
/* out13 = b1 * a10[1], b1 =
q6*/
"fmla v27.4s, v7.4s, v1.s[2]
\n
"
/* out14 = b1 * a10[2], b1 =
q6*/
"fmla v30.4s, v7.4s, v1.s[3]
\n
"
/* out15 = b1 * a10[3], b1 =
q6*/
"fmla v10.4s, v4.4s, v0.s[0]
\n
"
/* out16 = b2 * a10[0], b2 =
q7*/
"fmla v13.4s, v4.4s, v0.s[1]
\n
"
/* out17 = b2 * a10[0], b2 =
q7*/
"fmla v16.4s, v4.4s, v0.s[2]
\n
"
/* out18 = b2 * a10[0], b2 =
q7*/
"fmla v19.4s, v4.4s, v0.s[3]
\n
"
/* out19 = b2 * a10[0], b2 =
q7*/
"fmla v22.4s, v4.4s, v1.s[0]
\n
"
/* out20 = b2 * a10[0], b2 =
q7*/
"fmla v25.4s, v4.4s, v1.s[1]
\n
"
/* out21 = b2 * a10[0], b2 =
q7*/
"fmla v28.4s, v4.4s, v1.s[2]
\n
"
/* out22 = b2 * a10[0], b2 =
q7*/
"fmla v31.4s, v4.4s, v1.s[3]
\n
"
/* out23 = b2 * a10[0], b2 =
q7*/
"11:
\n
"
/* check if relu */
"cbz %w[relu], 12f
\n
"
/* skip relu */
"movi v2.4s, #0
\n
"
/* for relu*/
"fmax v8.4s, v8.4s, v2.4s
\n
"
/* relu*/
"fmax v9.4s, v9.4s, v2.4s
\n
"
/* relu*/
"fmax v10.4s, v10.4s, v2.4s
\n
"
/* relu*/
"fmax v11.4s, v11.4s, v2.4s
\n
"
/* relu*/
"fmax v12.4s, v12.4s, v2.4s
\n
"
/* relu*/
"fmax v13.4s, v13.4s, v2.4s
\n
"
/* relu*/
"fmax v14.4s, v14.4s, v2.4s
\n
"
/* relu*/
"fmax v15.4s, v15.4s, v2.4s
\n
"
/* relu*/
"fmax v16.4s,v16.4s,v2.4s
\n
"
/* relu*/
"fmax v17.4s,v17.4s,v2.4s
\n
"
/* relu*/
"fmax v18.4s, v18.4s, v2.4s
\n
"
/* relu*/
"fmax v19.4s, v19.4s, v2.4s
\n
"
/* relu*/
"fmax v20.4s, v20.4s, v2.4s
\n
"
/* relu*/
"fmax v21.4s, v21.4s, v2.4s
\n
"
/* relu*/
"fmax v22.4s, v22.4s, v2.4s
\n
"
/* relu*/
"fmax v23.4s, v23.4s, v2.4s
\n
"
/* relu*/
"fmax v24.4s,v24.4s,v2.4s
\n
"
/* relu*/
"fmax v25.4s,v25.4s,v2.4s
\n
"
/* relu*/
"fmax v26.4s, v26.4s, v2.4s
\n
"
/* relu*/
"fmax v27.4s, v27.4s, v2.4s
\n
"
/* relu*/
"fmax v28.4s, v28.4s, v2.4s
\n
"
/* relu*/
"fmax v29.4s, v29.4s, v2.4s
\n
"
/* relu*/
"fmax v30.4s, v30.4s, v2.4s
\n
"
/* relu*/
"fmax v31.4s, v31.4s, v2.4s
\n
"
/* relu*/
"12:
\n
"
"st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48
\n
"
/* store r0 */
"st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48
\n
"
/* store r1 */
"st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48
\n
"
/* store r2 */
"st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48
\n
"
/* store r3 */
"st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48
\n
"
/* store r4 */
"st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48
\n
"
/* store r5 */
"st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48
\n
"
/* store r6 */
"st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48
\n
"
/* store r7 */
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
k
]
"+r"
(
k
),
[
tail
]
"+r"
(
tail
),
[
c_ptr0
]
"+r"
(
c_ptr0
),
[
c_ptr1
]
"+r"
(
c_ptr1
),
[
c_ptr2
]
"+r"
(
c_ptr2
),
[
c_ptr3
]
"+r"
(
c_ptr3
),
[
c_ptr4
]
"+r"
(
c_ptr4
),
[
c_ptr5
]
"+r"
(
c_ptr5
),
[
c_ptr6
]
"+r"
(
c_ptr6
),
[
c_ptr7
]
"+r"
(
c_ptr7
)
:
[
bias_ptr
]
"r"
(
bias_local
),
[
relu
]
"r"
(
is_relu
)
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
        if (flag_p_remain && (xb == bblocks - 1)) {
          for (int i = 0; i < remain; ++i) {
            *pout0++ = cout0[i];
            *pout1++ = cout1[i];
            *pout2++ = cout2[i];
            *pout3++ = cout3[i];
            *pout4++ = cout4[i];
            *pout5++ = cout5[i];
            *pout6++ = cout6[i];
            *pout7++ = cout7[i];
          }
        }
      }
    }
  }
}
#else //__aarch64__
/**
 * \brief gemm with ablock = 6, bblock = 8, output tile 6x8
 * @param A_packed  packed A panel (m*k)
 * @param B         B matrix (k*n)
 * @param bias      per output row bias, read when is_bias is true
 * @param C         output matrix (m*n)
 * @param M, N, K   GEMM dimensions
 * @param is_bias   add bias to the output rows
 * @param is_relu   apply fused ReLU to the output
 * @param transB    when true, the B panel repack (loadb) is skipped
 */
void sgemm_conv_6x8(const float *A_packed, const float *B, const float *bias,
                    float *C, int M, int N, int K, bool is_bias, bool is_relu,
                    bool transB) {
  const int threads = framework::CPUContext::Context()->get_thread_num();
  int l2_size =
      framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float);
  int l2_cache = l2_size > 0 ? l2_size : 512 * 1024;
  //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
  int x_block =
      (l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH));
  x_block /= NBLOCK;
  x_block *= NBLOCK;
  int x_num = (N + (x_block - 1)) / x_block;
  x_block = (N + x_num - 1) / x_num;
  x_block = (x_block + NBLOCK - 1) / NBLOCK;
  x_block *= NBLOCK;
  x_block = x_block < NBLOCK ? NBLOCK : x_block;
  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
  int tail_pre = (K & (KBLOCK - 1));
  if (tail_pre == 0) {
    tail_pre = KBLOCK;
  }
  bool flag_p_remain = false;
  int remain = 0;
  //! apanel is pre_compute outside gemm
  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
    unsigned int xmax = x0 + x_block;
    if (xmax > N) {
      xmax = N;
    }
    int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
    remain = xmax - x0 - (bblocks - 1) * NBLOCK;
    if (remain > 0) {
      flag_p_remain = true;
    }
    //! load bpanel
    float *b_pannel = static_cast<float *>(
        framework::CPUContext::Context()->get_work_space(
            K * (xmax - x0) * sizeof(float)));
    if (!transB) {
      loadb(b_pannel, B, N, 0, K, x0, xmax);
    }
#pragma omp parallel for num_threads(threads)
    for (unsigned int y = 0; y < M; y += MBLOCK_OTH) {
      unsigned int ymax = y + MBLOCK_OTH;
      if (ymax > M) {
        ymax = M;
      }
      float *c_ptr0 = C + y * N + x0;
      float *c_ptr1 = c_ptr0 + N;
      float *c_ptr2 = c_ptr1 + N;
      float *c_ptr3 = c_ptr2 + N;
      float *c_ptr4 = c_ptr3 + N;
      float *c_ptr5 = c_ptr4 + N;
      float *pout0 = c_ptr0;
      float *pout1 = c_ptr1;
      float *pout2 = c_ptr2;
      float *pout3 = c_ptr3;
      float *pout4 = c_ptr4;
      float *pout5 = c_ptr5;
      float bias_local[6] = {0};
      if (is_bias) {
        bias_local[0] = bias[y];
        bias_local[1] = bias[y + 1];
        bias_local[2] = bias[y + 2];
        bias_local[3] = bias[y + 3];
        bias_local[4] = bias[y + 4];
        bias_local[5] = bias[y + 5];
      }
      float cout0[NBLOCK];
      float cout1[NBLOCK];
      float cout2[NBLOCK];
      float cout3[NBLOCK];
      float cout4[NBLOCK];
      float cout5[NBLOCK];
      const float *a_ptr_l = A_packed + y * K;
      const float *b_ptr = b_pannel;
      for (int xb = 0; xb < bblocks; xb++) {
        if ((y + 5) >= ymax) {
          switch ((y + 5) - ymax) {
            case 4:
              c_ptr1 = cout1;
            case 3:
              c_ptr2 = cout2;
            case 2:
              c_ptr3 = cout3;
            case 1:
              c_ptr4 = cout4;
            case 0:
              c_ptr5 = cout5;
            default:
              break;
          }
        }
        if (flag_p_remain && (xb == bblocks - 1)) {
          pout0 = c_ptr0;
          pout1 = c_ptr1;
          pout2 = c_ptr2;
          pout3 = c_ptr3;
          pout4 = c_ptr4;
          pout5 = c_ptr5;
          c_ptr0 = cout0;
          c_ptr1 = cout1;
          c_ptr2 = cout2;
          c_ptr3 = cout3;
          c_ptr4 = cout4;
          c_ptr5 = cout5;
        }
        const float *a_ptr = a_ptr_l;
        int tails = tail_pre;
        int k = k_pre;
        asm volatile(
            // sgemm 6x8
            "vld1.32 {d2-d4}, [%[bias_ptr]]    @ load bias 6 elements\n"
            "vld1.32 {d0-d1}, [%[a_ptr] :64]!  @ load a0~a3\n"
            "pld [%[a_ptr]]                    @ preload a\n"
            "vdup.i32 q12, d4[0]               @ out40=0\n"
            "pld [%[b_ptr]]                    @ preload b\n"
            "vdup.i32 q13, d4[0]               @ out41=0\n"
            "pld [%[a_ptr], #64]               @ preload a\n"
            "vdup.i32 q14, d4[1]               @ out50=0\n"
            "pld [%[b_ptr], #64]               @ preload b\n"
            "vdup.i32 q15, d4[1]               @ out51=0\n"
            "pld [%[a_ptr], #128]              @ preload a\n"
            "vdup.i32 q4, d2[0]                @ out00=0\n"
            "pld [%[b_ptr], #128]              @ preload b\n"
            "vdup.i32 q5, d2[0]                @ out01=0\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vdup.i32 q6, d2[1]                @ out10=0\n"
            "pld [%[a_ptr], #192]              @ preload a\n"
            "vdup.i32 q7, d2[1]                @ out11=0\n"
            "pld [%[b_ptr], #192]              @ preload b\n"
            "vdup.i32 q8, d3[0]                @ out20=0\n"
            "pld [%[a_ptr], #256]              @ preload a\n"
            "vdup.i32 q9, d3[0]                @ out21=0\n"
            "pld [%[b_ptr], #256]              @ preload b\n"
            "vdup.i32 q10, d3[1]               @ out30=0\n"
            "pld [%[b_ptr], #320]              @ preload b\n"
            "vdup.i32 q11, d3[1]               @ out31=0\n"
            "pld [%[b_ptr], #384]              @ preload b\n"
            "cmp %[k], #0                      @ check whether k > 0\n"
            "beq 0f                            @ jump to tail\n"
            "1:                                @ main loop for k\n"
            /* Unroll 0 */
            "vld1.32 {d2-d3}, [%[a_ptr] :64]!  @ load a4, a5, and next a0, a1\n"
            "vmla.f32 q4, q2, d0[0]            @ out0 += b1 * a0\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            "vmla.f32 q6, q2, d0[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d1[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d1[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d2[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d2[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d0[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d0[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d1[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d1[1]           @ out9 += b2 * a3\n"
            "vld1.32 {d0-d1}, [%[a_ptr] :64]!  @ load a2~a5\n"
            "vmla.f32 q13, q3, d2[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d2[1]           @ out11 += b2 * a5\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            /* Unroll 1 */
            "vmla.f32 q4, q2, d3[0]            @ out0 += b1 * a0\n"
            "vmla.f32 q6, q2, d3[1]            @ out1 += b1 * a1\n"
            /*"pld [%[a_ptr], #64]             @ preload a\n"*/
            "vmla.f32 q8, q2, d0[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d0[1]           @ out3 += b1 * a3\n"
            /*"pld [%[b_ptr], #192]\n"*/
            "vmla.f32 q12, q2, d1[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d1[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d3[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d3[1]            @ out7 += b2 * a1\n"
            "vld1.32 {d2-d3}, [%[a_ptr] :64]!  @ load a0~a3\n"
            "vmla.f32 q9, q3, d0[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d0[1]           @ out9 += b2 * a3\n"
            "vmla.f32 q13, q3, d1[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d1[1]           @ out11 += b2 * a5\n"
            "vld1.32 {d0-d1}, [%[a_ptr] :64]!  @ load a4, a5, a0, a1\n"
            /* Unroll 2 */
            "vmla.f32 q4, q2, d2[0]            @ out0 += b1 * a0\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            "vmla.f32 q6, q2, d2[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d3[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d3[1]           @ out3 += b1 * a3\n"
            /*"pld [%[a_ptr], #240]            @ preload\n"*/
            "vmla.f32 q12, q2, d0[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d0[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d2[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d2[1]            @ out7 += b2 * a1\n"
            /*"pld [%[b_ptr], #208]\n"*/
            "vmla.f32 q9, q3, d3[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d3[1]           @ out9 += b2 * a3\n"
            "vld1.32 {d2-d3}, [%[a_ptr] :64]!  @ load a2~a5\n"
            "vmla.f32 q13, q3, d0[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d0[1]           @ out11 += b2 * a5\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            /* Unroll 3 */
            "vmla.f32 q4, q2, d1[0]            @ out0 += b1 * a0\n"
            "vmla.f32 q6, q2, d1[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d2[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d2[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d3[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d3[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d1[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d1[1]            @ out7 += b2 * a1\n"
            "vld1.32 {d0-d1}, [%[a_ptr] :64]!  @ load a0~a3\n"
            "vmla.f32 q9, q3, d2[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d2[1]           @ out9 += b2 * a3\n"
            "subs %[k], %[k], #1               @ k--\n"
            "vmla.f32 q13, q3, d3[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d3[1]           @ out11 += b2 * a5\n"
            "bne 1b                            @ jump to main loop\n"
            "0:                                @ process tail\n"
            "subs %[tails], %[tails], #1       @ tail--\n"
            "beq 3f                            @ jump to tail = 1\n"
            /* Unroll 0 */
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            "vmla.f32 q4, q2, d0[0]            @ out0 += b1 * a0\n"
            "vld1.32 {d2-d3}, [%[a_ptr] :64]!  @ load a4, a5, a0, a1\n"
            "vmla.f32 q6, q2, d0[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d1[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d1[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d2[0]           @ out4 += b1 * a4\n"
            "subs %[tails], %[tails], #1       @ tail--\n"
            "vmla.f32 q14, q2, d2[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d0[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d0[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d1[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d1[1]           @ out9 += b2 * a3\n"
            "vld1.32 {d0-d1}, [%[a_ptr] :64]!  @ load a2~a5\n"
            "vmla.f32 q13, q3, d2[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d2[1]           @ out11 += b2 * a5\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            "beq 4f                            @ jump to tail==2\n"
            /* Unroll 1 */
            "vmla.f32 q4, q2, d3[0]            @ out0 += b1 * a0\n"
            "vmla.f32 q6, q2, d3[1]            @ out1 += b1 * a1\n"
            "subs %[tails], %[tails], #1       @ tail--\n"
            "vmla.f32 q8, q2, d0[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d0[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d1[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d1[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d3[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d3[1]            @ out7 += b2 * a1\n"
            "vld1.32 {d2-d3}, [%[a_ptr] :64]!  @ load a0~a3\n"
            "vmla.f32 q9, q3, d0[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d0[1]           @ out9 += b2 * a3\n"
            "vmla.f32 q13, q3, d1[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d1[1]           @ out11 += b2 * a5\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            "beq 5f                            @ jump to tail==3\n"
            /* Unroll 2 */
            "vld1.32 {d0-d1}, [%[a_ptr] :64]!  @ load a4, a5, a0, a1\n"
            "vmla.f32 q4, q2, d2[0]            @ out0 += b1 * a0\n"
            "vmla.f32 q6, q2, d2[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d3[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d3[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d0[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d0[1]           @ out5 += b1 * a5\n"
            "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
            "vmla.f32 q5, q3, d2[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d2[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d3[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d3[1]           @ out9 += b2 * a3\n"
            "vld1.32 {d2-d3}, [%[a_ptr] :64]!  @ load a2~a5\n"
            "vmla.f32 q13, q3, d0[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d0[1]           @ out11 += b2 * a5\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            /* Unroll 3 */
            "vmla.f32 q4, q2, d1[0]            @ out0 += b1 * a0\n"
            "vmla.f32 q6, q2, d1[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d2[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d2[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d3[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d3[1]           @ out5 += b1 * a5\n"
            "vmla.f32 q5, q3, d1[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d1[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d2[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d2[1]           @ out9 += b2 * a3\n"
            "vmla.f32 q13, q3, d3[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d3[1]           @ out11 += b2 * a5\n"
            "b 2f\n"
            /* tails==1 final tail */
            "3:                                @ tail=1\n"
            "vmla.f32 q4, q2, d0[0]            @ out0 += b1 * a0\n"
            "vld1.32 {d2}, [%[a_ptr] :64]!     @ load a4, a5\n"
            "vmla.f32 q6, q2, d0[1]            @ out1 += b1 * a1\n"
            "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
            "vmla.f32 q8, q2, d1[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d1[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d2[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d2[1]           @ out5 += b1 * a5\n"
            "vmla.f32 q5, q3, d0[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d0[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d1[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d1[1]           @ out9 += b2 * a3\n"
            "vmla.f32 q13, q3, d2[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d2[1]           @ out11 += b2 * a5\n"
            "b 2f                              @ jump to end\n"
            /* tails==2 final tail */
            "4:                                @ tail == 2\n"
            "vmla.f32 q4, q2, d3[0]            @ out0 += b1 * a0\n"
            "vmla.f32 q6, q2, d3[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d0[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d0[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d1[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d1[1]           @ out5 += b1 * a5\n"
            "vmla.f32 q5, q3, d3[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d3[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d0[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d0[1]           @ out9 += b2 * a3\n"
            "vmla.f32 q13, q3, d1[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d1[1]           @ out11 += b2 * a5\n"
            "b 2f                              @ jump to end\n"
            /* tails==3 final tail */
            "5:                                @ tail=3\n"
            "vmla.f32 q4, q2, d2[0]            @ out0 += b1 * a0\n"
            "vld1.32 {d0}, [%[a_ptr] :64]!     @ load a4, a5\n"
            "vmla.f32 q6, q2, d2[1]            @ out1 += b1 * a1\n"
            "vmla.f32 q8, q2, d3[0]            @ out2 += b1 * a2\n"
            "vmla.f32 q10, q2, d3[1]           @ out3 += b1 * a3\n"
            "vmla.f32 q12, q2, d0[0]           @ out4 += b1 * a4\n"
            "vmla.f32 q14, q2, d0[1]           @ out5 += b1 * a5\n"
            "vmla.f32 q5, q3, d2[0]            @ out6 += b2 * a0\n"
            "vmla.f32 q7, q3, d2[1]            @ out7 += b2 * a1\n"
            "vmla.f32 q9, q3, d3[0]            @ out8 += b2 * a2\n"
            "vmla.f32 q11, q3, d3[1]           @ out9 += b2 * a3\n"
            "vmla.f32 q13, q3, d0[0]           @ out10 += b2 * a4\n"
            "vmla.f32 q15, q3, d0[1]           @ out11 += b2 * a5\n"
            "2:                                @ check relu\n"
            "cmp %[relu], #0                   @ check if has relu\n"
            "ble 6f                            @ skip relu if relu <= 0\n"
            "vmov.u32 q0, #0                   @ for relu\n"
            "vmax.f32 q4, q4, q0               @ for relu\n"
            "vmax.f32 q5, q5, q0               @ for relu\n"
            "vmax.f32 q6, q6, q0               @ for relu\n"
            "vmax.f32 q7, q7, q0               @ for relu\n"
            "vmax.f32 q8, q8, q0               @ for relu\n"
            "vmax.f32 q9, q9, q0               @ for relu\n"
            "vmax.f32 q10, q10, q0             @ for relu\n"
            "vmax.f32 q11, q11, q0             @ for relu\n"
            "vmax.f32 q12, q12, q0             @ for relu\n"
            "vmax.f32 q13, q13, q0             @ for relu\n"
            "vmax.f32 q14, q14, q0             @ for relu\n"
            "vmax.f32 q15, q15, q0             @ for relu\n"
            "6:                                @ store result\n"
            "vst1.32 {d8-d11}, [%[c_ptr0]]!    @ store r0\n"
            "vst1.32 {d12-d15}, [%[c_ptr1]]!   @ store r1\n"
            "vst1.32 {d16-d19}, [%[c_ptr2]]!   @ store r2\n"
            "vst1.32 {d20-d23}, [%[c_ptr3]]!   @ store r3\n"
            "vst1.32 {d24-d27}, [%[c_ptr4]]!   @ store r4\n"
            "vst1.32 {d28-d31}, [%[c_ptr5]]!   @ store r5\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0),
              [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2),
              [c_ptr3] "+r"(c_ptr3), [c_ptr4] "+r"(c_ptr4),
              [c_ptr5] "+r"(c_ptr5), [k] "+r"(k), [tails] "+r"(tails)
            : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu)
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
        if (flag_p_remain && (xb == bblocks - 1)) {
          for (int i = 0; i < remain; ++i) {
            *pout0++ = cout0[i];
            *pout1++ = cout1[i];
            *pout2++ = cout2[i];
            *pout3++ = cout3[i];
            *pout4++ = cout4[i];
            *pout5++ = cout5[i];
          }
        }
      }
    }
  }
}
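
// --- Editor's sketch, not part of the patch ------------------------------
// Both ARMv7 kernels size their column panel from the L2 budget using the
// relation noted above: MBLOCK * x (result) + MBLOCK * K (A panel) +
// x * K (B panel) = L2 floats. Below is a minimal standalone version of
// that arithmetic. compute_x_block and its parameters are illustrative
// names, and the std::max guard is added here so this toy version cannot
// divide by zero when K alone overflows the budget.
#include <algorithm>

int compute_x_block(int N, int K, int l2_floats, int mblock, int nblock) {
  int x_block = (l2_floats - mblock * K) / (K + mblock);    // solve for x
  x_block = std::max((x_block / nblock) * nblock, nblock);  // round down
  int x_num = (N + x_block - 1) / x_block;                  // column panels
  x_block = (N + x_num - 1) / x_num;                        // balance widths
  return std::max(((x_block + nblock - 1) / nblock) * nblock, nblock);
}
// e.g. N = 1024, K = 256, a 512 KB L2 (131072 floats), MBLOCK_OTH = 6 and
// NBLOCK = 8 give x_block = 344, keeping the A, B and C tiles L2-resident.
// --------------------------------------------------------------------------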
void sgemm_conv_4x8(const float *A_packed, const float *B, const float *bias,
                    float *C, int M, int N, int K, bool is_bias, bool is_relu,
                    bool transB) {
  const int threads = framework::CPUContext::Context()->get_thread_num();
  int l2_size =
      framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float);
  int l2_cache = l2_size > 0 ? l2_size : 512 * 1024;
  //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
  int x_block =
      (l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73));
  x_block /= NBLOCK;
  x_block *= NBLOCK;
  int x_num = (N + (x_block - 1)) / x_block;
  x_block = (N + x_num - 1) / x_num;
  x_block = (x_block + NBLOCK - 1) / NBLOCK;
  x_block *= NBLOCK;
  x_block = x_block < NBLOCK ? NBLOCK : x_block;
  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
  int tail_pre = (K & (KBLOCK - 1));
  if (tail_pre == 0) {
    tail_pre = KBLOCK;
  }
  bool flag_p_remain = false;
  int remain = 0;
  //! apanel is pre_compute outside gemm
  for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
    unsigned int xmax = x0 + x_block;
    if (xmax > N) {
      xmax = N;
    }
    int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
    remain = xmax - x0 - (bblocks - 1) * NBLOCK;
    if (remain > 0) {
      flag_p_remain = true;
    }
    //! load bpanel
    float *b_pannel = static_cast<float *>(
        framework::CPUContext::Context()->get_work_space(
            K * (xmax - x0) * sizeof(float)));
    if (!transB) {
      loadb(b_pannel, B, N, 0, K, x0, xmax);
    }
#pragma omp parallel for num_threads(threads)
    for (unsigned int y = 0; y < M; y += MBLOCK_A73) {
      unsigned int ymax = y + MBLOCK_A73;
      if (ymax > M) {
        ymax = M;
      }
      float cout0[NBLOCK];
      float cout1[NBLOCK];
      float cout2[NBLOCK];
      float cout3[NBLOCK];
      float bias_local[4] = {0};
      if (is_bias) {
        bias_local[0] = bias[y];
        bias_local[1] = bias[y + 1];
        bias_local[2] = bias[y + 2];
        bias_local[3] = bias[y + 3];
      }
      float *c_ptr0 = C + y * N + x0;
      float *c_ptr1 = c_ptr0 + N;
      float *c_ptr2 = c_ptr1 + N;
      float *c_ptr3 = c_ptr2 + N;
      float *pout0 = c_ptr0;
      float *pout1 = c_ptr1;
      float *pout2 = c_ptr2;
      float *pout3 = c_ptr3;
      const float *a_ptr_l = A_packed + y * K;
      const float *b_ptr = b_pannel;
      for (int xb = 0; xb < bblocks; xb++) {
        if ((y + 3) >= ymax) {
          switch ((y + 3) - ymax) {
            case 2:
              c_ptr1 = cout1;
            case 1:
              c_ptr2 = cout1;
            case 0:
              c_ptr3 = cout1;
            default:
              break;
          }
        }
        if (flag_p_remain && (xb == bblocks - 1)) {
          pout0 = c_ptr0;
          pout1 = c_ptr1;
          pout2 = c_ptr2;
          pout3 = c_ptr3;
          c_ptr0 = cout0;
          c_ptr1 = cout1;
          c_ptr2 = cout2;
          c_ptr3 = cout3;
        }
        const float *a_ptr = a_ptr_l;
        int tails = tail_pre;
        int k = k_pre;
        asm volatile(
"vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias
\n
"
"vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load a0~a3
\n
"
"vdup.32 q8, d4[0] @ add bias to out00
\n
"
"pld [%[a_ptr]] @ preload a, 64byte
\n
"
"vdup.32 q9, d4[0] @ add bias to out01
\n
"
"pld [%[b_ptr]] @ preload b
\n
"
"vdup.32 q10, d4[1] @ add bias to out10
\n
"
"pld [%[a_ptr], #64] @ preload a
\n
"
"vdup.32 q11, d4[1] @ add bias to out11
\n
"
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load b1
\n
"
"vdup.32 q12, d5[0] @ add bias to out20
\n
"
"pld [%[b_ptr], #64] @ preload b
\n
"
"vdup.32 q13, d5[0] @ add bias to out21
\n
"
"pld [%[a_ptr], #128] @ preload a
\n
"
"vdup.32 q14, d5[1] @ add bias to out30
\n
"
"pld [%[b_ptr], #128] @ preload b
\n
"
"vdup.32 q15, d5[1] @ add bias to out31
\n
"
"pld [%[b_ptr], #192] @ preload b
\n
"
"cmp %[k], #0 @ check weather k is "
"bigger than 0
\n
"
"beq 0f @ jump to tail
\n
"
"1: @ main loop for k
\n
"
/* Unroll 0*/
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2
\n
"
"vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0
\n
"
"vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3
\n
"
"vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3
\n
"
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2
\n
"
/* Unroll 1 */
"vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0
\n
"
"pld [%[b_ptr], #64] @ preload b
\n
"
"vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0
\n
"
"vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1
\n
"
"vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2
\n
"
"vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3
\n
"
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2
\n
"
/* Unroll 2 */
"vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0
\n
"
"vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load next a0~a3
\n
"
"vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3
\n
"
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2
\n
"
/* Unroll 3 */
"vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0
\n
"
"pld [%[a_ptr], #64] @ preload a
\n
"
"vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3
\n
"
"subs %[k], %[k], #1 @ k--
\n
"
"bne 1b @ jump to main loop
\n
"
"0: @ process tail
\n
"
"subs %[tails], %[tails], #1 @ tail--
\n
"
"beq 3f @ jump to tail = 1
\n
"
/* Unroll 0*/
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2
\n
"
"vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0
\n
"
// b1*a1
"vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1
\n
"
"subs %[tails], %[tails], #1 @ tail--
\n
"
"vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3
\n
"
"beq 4f @ jump to tail==2
\n
"
/* Unroll 1 */
"vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2
\n
"
"vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0
\n
"
// b6*a2
"vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3
\n
"
"vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1
\n
"
"subs %[tails], %[tails], #1 @ tail--
\n
"
"vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0
\n
"
"vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1
\n
"
"vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2
\n
"
"vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3
\n
"
"beq 5f @ jump to tail==3
\n
"
/* Unroll 2 */
"vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2
\n
"
"vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0
\n
"
// b11
// *
// a3
"vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3
\n
"
/* Unroll 3 */
"vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0
\n
"
// b16
// *
// a4
"vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3
\n
"
"b 2f
\n
"
/* tails==1 final tail */
"3: @ tail=1
\n
"
"vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0
\n
"
"vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3
\n
"
/*aptr - 16 */
"sub %[a_ptr], %[a_ptr], #16 @ tail--
\n
"
"b 2f @ jump to end
\n
"
/* tails==2 final tail*/
"4: @ tail == 2
\n
"
"vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0
\n
"
"vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q7, d2[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q7, d2[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q7, d3[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q7, d3[1] @ out7 += b2 * a3
\n
"
"b 2f @ jump to end
\n
"
/* tails==3 final tail*/
"5: @ tail=3
\n
"
"vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0
\n
"
"vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1
\n
"
"vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2
\n
"
"vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3
\n
"
"vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0
\n
"
"vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1
\n
"
"vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2
\n
"
"vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3
\n
"
/*aptr - 16*/
"sub %[a_ptr], %[a_ptr], #16 @ tail--
\n
"
"2: @ check relu
\n
"
"cmp %[relu], #0 @ check if has relu
\n
"
"ble 6f @ skip relu if relu <= 0
\n
"
"vmov.u32 q0, #0 @ for relu
\n
"
"vmax.f32 q8, q8, q0 @ for relu
\n
"
"vmax.f32 q9, q9, q0 @ for relu
\n
"
"vmax.f32 q10, q10, q0 @ for relu
\n
"
"vmax.f32 q11, q11, q0 @ for relu
\n
"
"vmax.f32 q12, q12, q0 @ for relu
\n
"
"vmax.f32 q13, q13, q0 @ for relu
\n
"
"vmax.f32 q14, q14, q0 @ for relu
\n
"
"vmax.f32 q15, q15, q0 @ for relu
\n
"
"6: @ store result
\n
"
"vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0
\n
"
"vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1
\n
"
"vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2
\n
"
"vst1.32 {d28-d31}, [%[c_ptr3]]! @ store r3
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
c_ptr0
]
"+r"
(
c_ptr0
),
[
c_ptr1
]
"+r"
(
c_ptr1
),
[
c_ptr2
]
"+r"
(
c_ptr2
),
[
c_ptr3
]
"+r"
(
c_ptr3
),
[
k
]
"+r"
(
k
),
[
tails
]
"+r"
(
tails
)
:
[
bias_ptr
]
"r"
(
bias_local
),
[
relu
]
"r"
(
is_relu
)
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
        if (flag_p_remain && (xb == bblocks - 1)) {
          for (int i = 0; i < remain; ++i) {
            *pout0++ = cout0[i];
            *pout1++ = cout1[i];
            *pout2++ = cout2[i];
            *pout3++ = cout3[i];
          }
        }
      }
    }
  }
}
#endif //__aarch64__
/// a: m*k b: k*n c: m*n
void sgemm_prepack(const float *A_packed, const float *B, const float *bias,
                   float *C, int M, int N, int K, bool is_bias, bool is_relu,
                   bool is_transB, ARMArch arch) {
#ifdef __aarch64__
  sgemm_conv_8x12(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB);
#else  // armv7
  if (arch == A73) {
    sgemm_conv_4x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu,
                   is_transB);
  } else {
    sgemm_conv_6x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu,
                   is_transB);
  }
#endif  // arm64
}
}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile

#endif  // CONV_OP
#endif  // __ARM_NEON__
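
As a quick sanity check for the kernels above, a scalar reference with the same contract (C = A·B plus optional per-row bias and fused ReLU) is handy as a test oracle. This sketch is an editor's addition, not part of the commit; it takes plain row-major A rather than the packed panel that sgemm_prepack consumes.

// Scalar oracle matching the semantics of sgemm_prepack: bias is one value
// per output row, and ReLU is applied after the bias-initialized accumulation.
void sgemm_ref(const float *A, const float *B, const float *bias, float *C,
               int M, int N, int K, bool is_bias, bool is_relu) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = is_bias ? bias[m] : 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = (is_relu && acc < 0.0f) ? 0.0f : acc;
    }
  }
}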
src/operators/math/gemm/gemm1x1s1.h
0 → 100644
View file @ 74a309cb
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONV_OP

#pragma once

#include "framework/tensor.h"

namespace paddle_mobile {
namespace operators {
namespace math {

#ifdef __aarch64__
const int MBLOCK = 8;
const int NBLOCK = 12;
const int KBLOCK = 4;
inline int get_hblock(ARMArch arch) { return MBLOCK; }
#else
const int MBLOCK_A73 = 4;
const int MBLOCK_OTH = 6;
const int NBLOCK = 8;
const int KBLOCK = 4;
inline int get_hblock(ARMArch arch) {
  if (arch == A73) {
    return MBLOCK_A73;
  } else {
    return MBLOCK_OTH;
  }
}
#endif  // __aarch64__

void gemm1x1s1_transform_weight(const framework::Tensor &weight,
                                const framework::Tensor &output,
                                framework::Tensor *trans_weight,
                                const int group, ARMArch arch);

void sgemm_prepack(const float *A_packed, const float *B, const float *bias,
                   float *C, int M, int N, int K, bool is_bias, bool is_relu,
                   bool is_transB, ARMArch arch);

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile

#endif  // CONV_OP
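
Given get_hblock above, a caller sizing the packed-A buffer would round M up to the architecture's height block. The helper below is a hedged sketch of that arithmetic: the name packed_A_bytes is an assumption, mirroring the row-block layout the kernels in gemm1x1s1.cpp consume.

#include <cstddef>

// Illustrative sizing: the kernels read A in row blocks of get_hblock(arch)
// rows, each K floats deep, so M is rounded up to a multiple of the block.
inline size_t packed_A_bytes(int M, int K, ARMArch arch) {
  const int hblock = get_hblock(arch);  // 8 on aarch64; 4 (A73) or 6 on armv7
  const int m_round = ((M + hblock - 1) / hblock) * hblock;
  return static_cast<size_t>(m_round) * K * sizeof(float);
}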
src/operators/op_param.h
View file @ 74a309cb
...
@@ -467,6 +467,7 @@ class ConvParam : public OpParam {
     EXEC_SLIDINGWINDOW3x3_FLOAT,
     EXEC_SLIDINGWINDOW5x5_FLOAT,
     EXEC_SLIDINGWINDOW7x7_FLOAT,
+    EXEC_GEMM1x1s1_FLOAT,
   };
   ExecMode &ExecMode() const { return exec_mode_; }
...
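
The new EXEC_GEMM1x1s1_FLOAT mode only fires when conv setup routes eligible shapes to it; that selection lives in conv_common.cpp (changed by this commit but not shown here), so the predicate below is a hypothetical sketch of the shape test such a dispatch would perform, not the patch's actual code.

// Hypothetical eligibility test: a 1x1 kernel with stride 1 and no padding
// or dilation reduces to a plain GEMM over the input feature map.
bool prefers_gemm1x1s1(int kh, int kw, int stride_h, int stride_w,
                       int pad_h, int pad_w, int dilation) {
  return kh == 1 && kw == 1 && stride_h == 1 && stride_w == 1 &&
         pad_h == 0 && pad_w == 0 && dilation == 1;
}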