Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8fe8edf4
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8fe8edf4
编写于
10月 09, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add fp16 mk8 16x12 matmul algo
GitOrigin-RevId: d9978fa5890e3feb8f2383fce4fa3b4f00cd93bf
上级
21a975f8
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
757 addition
and
2 deletion
+757
-2
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+55
-0
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+11
-0
dnn/src/aarch64/matrix_mul/asm/common.h
dnn/src/aarch64/matrix_mul/asm/common.h
+80
-0
dnn/src/aarch64/matrix_mul/fp16/kernel_mk8_16x12.h
dnn/src/aarch64/matrix_mul/fp16/kernel_mk8_16x12.h
+160
-0
dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc
dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc
+321
-0
dnn/src/aarch64/matrix_mul/fp16/strategy.h
dnn/src/aarch64/matrix_mul/fp16/strategy.h
+3
-0
dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_16x12.cpp
dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_16x12.cpp
+107
-0
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+2
-0
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+3
-2
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+1
-0
dnn/test/aarch64/matrix_mul.cpp
dnn/test/aarch64/matrix_mul.cpp
+14
-0
未找到文件。
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
8fe8edf4
...
...
@@ -352,6 +352,61 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
return
kern_mk8_8x8
;
}
/* ==================== F16_MK8_16x12x1 algo ====================*/
bool
MatrixMulImpl
::
AlgoF16MK8_16x12x1
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
kern_size_param
.
C_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
B_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
A_type
==
dtype
::
Float16
()
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
MK8
&&
!
kern_size_param
.
trA
&&
!
kern_size_param
.
trB
;
}
size_t
MatrixMulImpl
::
AlgoF16MK8_16x12x1
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_aarch64_matmul_kern
,
midout_iv
(
"AlgoF16MK8_16x12x1::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
aarch64
::
matmul
::
hgemm_mk8_16x12
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
aarch64
::
matmul
::
hgemm_mk8_16x12
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
return
0
;
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoF16MK8_16x12x1
::
get_kern
(
const
KernSizeParam
&
)
const
{
auto
kern_mk8_16x12x1
=
[](
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
MIDOUT_BEGIN
(
megdnn_aarch64_matmul_kern
,
midout_iv
(
"AlgoF16MK8_16x12x1::get_kern"
_hash
))
{
auto
M
=
kern_param
.
M
,
N
=
kern_param
.
N
,
K
=
kern_param
.
K
;
auto
trA
=
kern_param
.
trA
,
trB
=
kern_param
.
trB
;
auto
LDA
=
kern_param
.
LDA
,
LDB
=
kern_param
.
LDB
,
LDC
=
kern_param
.
LDC
;
auto
A_type
=
kern_param
.
A_type
,
B_type
=
kern_param
.
B_type
,
C_type
=
kern_param
.
C_type
;
const
auto
Aptr
=
kern_param
.
A
<
dt_float16
>
(),
Bptr
=
kern_param
.
B
<
dt_float16
>
();
auto
Cptr
=
kern_param
.
C
<
dt_float16
>
();
aarch64
::
matmul
::
hgemm_mk8_16x12
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
aarch64
::
matmul
::
hgemm_mk8_16x12
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
}
MIDOUT_END
();
};
return
kern_mk8_16x12x1
;
}
#endif
#if MGB_ENABLE_DOT
...
...
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
8fe8edf4
...
...
@@ -86,6 +86,17 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_F16_MK8_8X8
)
};
class
MatrixMulImpl
::
AlgoF16MK8_16x12x1
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
;
}
const
char
*
name
()
const
override
{
return
"AARCH64_F16_MK8_16X12X1"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
MEGDNN_OVERRIDE_MATMUL_DESC
(
16
,
12
,
1
,
2
,
AlgoDataType
::
FLOAT16
,
MK8
);
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_F16_MK8_16X12X1
);
};
#endif
#if MGB_ENABLE_DOT
...
...
dnn/src/aarch64/matrix_mul/asm/common.h
浏览文件 @
8fe8edf4
...
...
@@ -1142,6 +1142,29 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, T* out
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
interleave_2x8_2_h
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
T
*
outptr
)
{
static_assert
(
sizeof
(
T
)
==
2
,
"interleave_2x8_2_s only support size == 2"
);
asm
volatile
(
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[inptr0]], #64
\n
"
"ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%[inptr0]], #64
\n
"
"ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [%[inptr1]], #64
\n
"
"ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [%[inptr1]], #64
\n
"
"stp q0, q8, [%[outptr]]
\n
"
"stp q1, q9, [%[outptr], #32]
\n
"
"stp q2, q10, [%[outptr], #64]
\n
"
"stp q3, q11, [%[outptr], #96]
\n
"
"stp q4, q12, [%[outptr], #128]
\n
"
"stp q5, q13, [%[outptr], #160]
\n
"
"stp q6, q14, [%[outptr], #192]
\n
"
"stp q7, q15, [%[outptr], #224]
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
inptr1
]
"+r"
(
inptr1
),
[
outptr
]
"+r"
(
outptr
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
interleave_1x4_4_s
(
const
T
*&
inptr0
,
T
*
outptr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"interleave_1x4_4_s only support size == 4"
);
...
...
@@ -1154,6 +1177,20 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
interleave_1x8_2_h
(
const
T
*&
inptr0
,
T
*
outptr
)
{
static_assert
(
sizeof
(
T
)
==
2
,
"interleave_1x8_2_s only support size == 2"
);
asm
volatile
(
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[inptr0]], #64
\n
"
"ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%[inptr0]], #64
\n
"
"st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[outptr]], #64
\n
"
"st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%[outptr]], #64
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
outptr
]
"+r"
(
outptr
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
interleave_4x8_2_b
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1548,6 +1585,49 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
"memory"
);
}
template
<
typename
T
>
static
inline
void
transpose_1x12_2_h
(
const
T
*&
inptr
,
T
*
outptr
)
{
static_assert
(
sizeof
(
T
)
==
2
,
"transpose_1x12_2_s only support sizeof(T) == 2"
);
asm
volatile
(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr]], #64
\n
"
"ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr]], #64
\n
"
"ld4 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[inptr]], #64
\n
"
"uzp1 v12.8h, v0.8h, v4.8h
\n
"
"st1 {v12.8h}, [%[outptr]], #16
\n
"
// line[0][0-7]
"uzp1 v14.8h, v8.8h, v9.8h
\n
"
"st1 {v14.d}[0], [%[outptr]], #8
\n
"
// line[0][8-11]
"uzp2 v13.8h, v0.8h, v4.8h
\n
"
"st1 {v13.8h}, [%[outptr]], #16
\n
"
// line[1][0-7]
"uzp2 v15.8h, v8.8h, v9.8h
\n
"
"st1 {v15.d}[0], [%[outptr]], #8
\n
"
// line[1][8-11]
"uzp1 v16.8h, v1.8h, v5.8h
\n
"
"st1 {v16.8h}, [%[outptr]], #16
\n
"
// line[2][0-7]
"st1 {v14.d}[1], [%[outptr]], #8
\n
"
// line[2][8-11]
"uzp2 v17.8h, v1.8h, v5.8h
\n
"
"st1 {v17.8h}, [%[outptr]], #16
\n
"
// line[3][0-7]
"st1 {v15.d}[1], [%[outptr]], #8
\n
"
// line[3][8-11]
"uzp1 v18.8h, v2.8h, v6.8h
\n
"
"st1 {v18.8h}, [%[outptr]], #16
\n
"
// line[4][0-7]
"uzp1 v19.8h, v10.8h, v11.8h
\n
"
"st1 {v19.d}[0], [%[outptr]], #8
\n
"
// line[4][8-11]
"uzp2 v20.8h, v2.8h, v6.8h
\n
"
"st1 {v20.8h}, [%[outptr]], #16
\n
"
// line[5][0-7]
"uzp2 v21.8h, v10.8h, v11.8h
\n
"
"st1 {v21.d}[0], [%[outptr]], #8
\n
"
// line[5][8-11]
"uzp1 v22.8h, v3.8h, v7.8h
\n
"
"st1 {v22.8h}, [%[outptr]], #16
\n
"
// line[6][0-7]
"st1 {v19.d}[1], [%[outptr]], #8
\n
"
// line[6][8-11]
"uzp2 v23.8h, v3.8h, v7.8h
\n
"
"st1 {v23.8h}, [%[outptr]], #16
\n
"
// line[7][0-7]
"st1 {v21.d}[1], [%[outptr]], #8
\n
"
// line[7][8-11]
:
[
inptr
]
"+r"
(
inptr
),
[
outptr
]
"+r"
(
outptr
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
transpose_1x4_4_s
(
const
T
*&
inptr0
,
T
*
outptr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"transpose_1x4_4_s only support sizeof(T) == 4"
);
...
...
dnn/src/aarch64/matrix_mul/fp16/kernel_mk8_16x12.h
0 → 100644
浏览文件 @
8fe8edf4
#pragma once
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace
megdnn
{
namespace
aarch64
{
struct
matmul_mk8_16x12
{
template
<
size_t
M_BLOCK
,
size_t
N_BLOCK
>
static
void
kern
(
const
dt_float16
*
packedA
,
const
dt_float16
*
packedB
,
int
K
,
dt_float16
*
out
,
int
LDC
,
bool
is_first_k
);
static
void
hgemm_16x12_pack_A
(
dt_float16
*
outptr
,
const
dt_float16
*
inptr
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
)
{
megdnn_assert
(
y0
%
8
==
0
&&
ymax
%
8
==
0
,
"M must be time of 8"
);
megdnn_assert
(
k0
%
8
==
0
&&
kmax
%
8
==
0
,
"K must be time of 8"
);
constexpr
int
PACK_SIZE_128
=
16
*
8
;
constexpr
int
PACK_SIZE_64
=
8
*
8
;
constexpr
int
PACK_C_SIZE
=
8
;
int
y
=
y0
;
for
(;
y
+
15
<
ymax
;
y
+=
16
)
{
const
dt_float16
*
inptr0
=
inptr
+
y
/
PACK_C_SIZE
*
ldin
+
k0
;
const
dt_float16
*
inptr1
=
inptr0
+
ldin
;
prefetch_4x
(
inptr0
);
prefetch_4x
(
inptr1
);
for
(
int
k
=
k0
;
k
<
kmax
;
k
+=
8
)
{
interleave_2x8_2_h
(
inptr0
,
inptr1
,
outptr
);
outptr
+=
PACK_SIZE_128
;
}
}
for
(;
y
<
ymax
;
y
+=
8
)
{
const
dt_float16
*
inptr0
=
inptr
+
y
/
PACK_C_SIZE
*
ldin
+
k0
;
prefetch_4x
(
inptr0
);
for
(
int
k
=
k0
;
k
<
kmax
;
k
+=
8
)
{
interleave_1x8_2_h
(
inptr0
,
outptr
);
outptr
+=
PACK_SIZE_64
;
}
}
}
static
void
hgemm_16x12_pack_B
(
dt_float16
*
out
,
const
dt_float16
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
)
{
megdnn_assert
(
k0
%
8
==
0
&&
kmax
%
8
==
0
,
"K must be time of 8"
);
dt_float16
tmpbuff
[
96
]
=
{
static_cast
<
dt_float16
>
(
0.0
)};
constexpr
int
PACK_C_SIZE
=
8
;
int
ksize
=
kmax
-
k0
;
int
ksize12
=
ksize
*
12
;
dt_float16
*
outptr_base
=
out
;
for
(
int
k
=
k0
;
k
<
kmax
;
k
+=
8
)
{
const
dt_float16
*
inptr
=
in
+
k
/
PACK_C_SIZE
*
ldin
+
x0
*
PACK_C_SIZE
;
prefetch_3x
(
inptr
);
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
12
<=
xmax
;
x
+=
12
)
{
auto
outptr_interleave
=
outptr
;
transpose_1x12_2_h
(
inptr
,
outptr_interleave
);
outptr
+=
ksize12
;
}
if
(
x
<
xmax
)
{
std
::
memcpy
(
tmpbuff
,
inptr
,
sizeof
(
dt_float16
)
*
(
xmax
-
x
)
*
PACK_C_SIZE
);
auto
outptr_interleave
=
outptr
;
inptr
=
tmpbuff
;
transpose_1x12_2_h
(
inptr
,
outptr_interleave
);
}
outptr_base
+=
12
*
8
;
}
}
};
#define M_BLOCK 1
#define N_BLOCK 1
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 2
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 3
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 4
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 5
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 6
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 7
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 8
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 9
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 10
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 11
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 12
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#undef M_BLOCK
#define M_BLOCK 2
#define N_BLOCK 1
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 2
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 3
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 4
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 5
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 6
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 7
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 8
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 9
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 10
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 11
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#define N_BLOCK 12
#include "mk8_16x12_kern.inc"
#undef N_BLOCK
#undef M_BLOCK
}
// namespace aarch64
}
// namespace megdnn
#endif
\ No newline at end of file
dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc
0 → 100644
浏览文件 @
8fe8edf4
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#ifndef _STR
#define _STR(X) #X
#endif
#ifndef STR
#define STR(X) _STR(X)
#endif
template
<>
void
matmul_mk8_16x12
::
kern
<
M_BLOCK
,
N_BLOCK
>
(
const
dt_float16
*
packedA
,
const
dt_float16
*
packedB
,
int
K
,
dt_float16
*
out
,
int
LDC
,
bool
is_first_k
)
{
#define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n"
#define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n"
// clang-format off
#define IF_MN_GT(M, N, INSTRUC) \
".if "
STR
(
M_BLOCK
)
" > "
#M "\n" \
".if "
STR
(
N_BLOCK
)
" > "
#N "\n" \
INSTRUC
\
".endif
\n
"
\
".endif
\n
"
const
dt_float16
*
a_ptr
=
packedA
;
const
dt_float16
*
b_ptr
=
packedB
;
dt_float16
*
outptr0
=
out
;
dt_float16
*
outptr1
=
out
+
LDC
;
int
oddK
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
asm
volatile
(
"cmp %w[is_first_k], #1
\n
"
"beq 1f
\n
"
IF_M_GT
(
0
,
"mov x1, %[outptr0]
\n
"
)
IF_M_GT
(
1
,
"mov x2, %[outptr1]
\n
"
)
IF_MN_GT
(
0
,
0
,
"ld1
{
v8.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
1
,
"ld1
{
v9.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
2
,
"ld1
{
v10.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
3
,
"ld1
{
v11.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
4
,
"ld1
{
v12.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
5
,
"ld1
{
v13.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
6
,
"ld1
{
v14.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
7
,
"ld1
{
v15.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
8
,
"ld1
{
v16.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
9
,
"ld1
{
v17.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
10
,
"ld1
{
v18.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
0
,
11
,
"ld1
{
v19.8h
}
, [x1], #16
\n
"
)
IF_MN_GT
(
1
,
0
,
"ld1
{
v20.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
1
,
"ld1
{
v21.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
2
,
"ld1
{
v22.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
3
,
"ld1
{
v23.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
4
,
"ld1
{
v24.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
5
,
"ld1
{
v25.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
6
,
"ld1
{
v26.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
7
,
"ld1
{
v27.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
8
,
"ld1
{
v28.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
9
,
"ld1
{
v29.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
10
,
"ld1
{
v30.8h
}
, [x2], #16
\n
"
)
IF_MN_GT
(
1
,
11
,
"ld1
{
v31.8h
}
, [x2], #16
\n
"
)
IF_M_GT
(
0
,
"ld1
{
v0.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_N_GT
(
0
,
"ld1
{
v2.8h
}
, [%[b_ptr]], #16
\n
"
)
"b 2f
\n
"
"1:
\n
"
IF_MN_GT
(
0
,
0
,
"eor v8.16b, v8.16b, v8.16b
\n
"
)
IF_MN_GT
(
0
,
1
,
"eor v9.16b, v9.16b, v9.16b
\n
"
)
IF_MN_GT
(
0
,
2
,
"eor v10.16b, v10.16b, v10.16b
\n
"
)
"prfm pstl1keep, [%[outptr0]]
\n
"
IF_MN_GT
(
0
,
3
,
"eor v11.16b, v11.16b, v11.16b
\n
"
)
IF_MN_GT
(
0
,
4
,
"eor v12.16b, v12.16b, v12.16b
\n
"
)
IF_MN_GT
(
0
,
5
,
"eor v13.16b, v13.16b, v13.16b
\n
"
)
"prfm pstl1keep, [%[outptr1]]
\n
"
IF_MN_GT
(
0
,
6
,
"eor v14.16b, v14.16b, v14.16b
\n
"
)
IF_MN_GT
(
0
,
7
,
"eor v15.16b, v15.16b, v15.16b
\n
"
)
IF_MN_GT
(
0
,
8
,
"eor v16.16b, v16.16b, v16.16b
\n
"
)
IF_N_GT
(
0
,
"ld1
{
v2.8h
}
, [%[b_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
9
,
"eor v17.16b, v17.16b, v17.16b
\n
"
)
IF_MN_GT
(
0
,
10
,
"eor v18.16b, v18.16b, v18.16b
\n
"
)
IF_MN_GT
(
0
,
11
,
"eor v19.16b, v19.16b, v19.16b
\n
"
)
IF_MN_GT
(
1
,
0
,
"eor v20.16b, v20.16b, v20.16b
\n
"
)
IF_MN_GT
(
1
,
1
,
"eor v21.16b, v21.16b, v21.16b
\n
"
)
IF_M_GT
(
0
,
"ld1
{
v0.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
1
,
2
,
"eor v22.16b, v22.16b, v22.16b
\n
"
)
IF_MN_GT
(
1
,
3
,
"eor v23.16b, v23.16b, v23.16b
\n
"
)
IF_MN_GT
(
1
,
4
,
"eor v24.16b, v24.16b, v24.16b
\n
"
)
IF_MN_GT
(
1
,
5
,
"eor v25.16b, v25.16b, v25.16b
\n
"
)
IF_MN_GT
(
1
,
6
,
"eor v26.16b, v26.16b, v26.16b
\n
"
)
IF_MN_GT
(
1
,
7
,
"eor v27.16b, v27.16b, v27.16b
\n
"
)
IF_MN_GT
(
1
,
8
,
"eor v28.16b, v28.16b, v28.16b
\n
"
)
IF_MN_GT
(
1
,
9
,
"eor v29.16b, v29.16b, v29.16b
\n
"
)
IF_MN_GT
(
1
,
10
,
"eor v30.16b, v30.16b, v30.16b
\n
"
)
IF_MN_GT
(
1
,
11
,
"eor v31.16b, v31.16b, v31.16b
\n
"
)
"2:
\n
"
"cmp %w[K], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"ld1
{
v3.8h
}
, [%[b_ptr]], #16
\n
"
IF_MN_GT
(
0
,
0
,
"fmla v8.8h, v0.8h, v2.h[0]
\n
"
)
IF_MN_GT
(
0
,
1
,
"fmla v9.8h, v0.8h, v2.h[1]
\n
"
)
IF_MN_GT
(
0
,
2
,
"fmla v10.8h, v0.8h, v2.h[2]
\n
"
)
IF_MN_GT
(
0
,
3
,
"fmla v11.8h, v0.8h, v2.h[3]
\n
"
)
IF_M_GT
(
1
,
"ld1
{
v1.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
4
,
"fmla v12.8h, v0.8h, v2.h[4]
\n
"
)
IF_MN_GT
(
0
,
5
,
"fmla v13.8h, v0.8h, v2.h[5]
\n
"
)
IF_MN_GT
(
0
,
6
,
"fmla v14.8h, v0.8h, v2.h[6]
\n
"
)
IF_MN_GT
(
0
,
7
,
"fmla v15.8h, v0.8h, v2.h[7]
\n
"
)
IF_M_GT
(
0
,
"ld1
{
v5.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
8
,
"fmla v16.8h, v0.8h, v3.h[0]
\n
"
)
IF_MN_GT
(
0
,
9
,
"fmla v17.8h, v0.8h, v3.h[1]
\n
"
)
IF_MN_GT
(
0
,
10
,
"fmla v18.8h, v0.8h, v3.h[2]
\n
"
)
IF_MN_GT
(
0
,
11
,
"fmla v19.8h, v0.8h, v3.h[3]
\n
"
)
IF_MN_GT
(
1
,
0
,
"fmla v20.8h, v1.8h, v2.h[0]
\n
"
)
IF_MN_GT
(
1
,
1
,
"fmla v21.8h, v1.8h, v2.h[1]
\n
"
)
IF_MN_GT
(
1
,
2
,
"fmla v22.8h, v1.8h, v2.h[2]
\n
"
)
IF_MN_GT
(
1
,
3
,
"fmla v23.8h, v1.8h, v2.h[3]
\n
"
)
IF_MN_GT
(
1
,
4
,
"fmla v24.8h, v1.8h, v2.h[4]
\n
"
)
IF_MN_GT
(
1
,
5
,
"fmla v25.8h, v1.8h, v2.h[5]
\n
"
)
"ld1
{
v4.8h
}
, [%[b_ptr]], #16
\n
"
IF_MN_GT
(
1
,
6
,
"fmla v26.8h, v1.8h, v2.h[6]
\n
"
)
IF_MN_GT
(
1
,
7
,
"fmla v27.8h, v1.8h, v2.h[7]
\n
"
)
IF_MN_GT
(
1
,
8
,
"fmla v28.8h, v1.8h, v3.h[0]
\n
"
)
IF_MN_GT
(
1
,
9
,
"fmla v29.8h, v1.8h, v3.h[1]
\n
"
)
IF_MN_GT
(
1
,
10
,
"fmla v30.8h, v1.8h, v3.h[2]
\n
"
)
IF_MN_GT
(
1
,
11
,
"fmla v31.8h, v1.8h, v3.h[3]
\n
"
)
IF_M_GT
(
1
,
"ld1
{
v6.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
0
,
"fmla v8.8h, v5.8h, v3.h[4]
\n
"
)
IF_MN_GT
(
0
,
1
,
"fmla v9.8h, v5.8h, v3.h[5]
\n
"
)
IF_M_GT
(
0
,
"ld1
{
v0.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
2
,
"fmla v10.8h, v5.8h, v3.h[6]
\n
"
)
IF_MN_GT
(
0
,
3
,
"fmla v11.8h, v5.8h, v3.h[7]
\n
"
)
IF_MN_GT
(
0
,
4
,
"fmla v12.8h, v5.8h, v4.h[0]
\n
"
)
IF_MN_GT
(
0
,
5
,
"fmla v13.8h, v5.8h, v4.h[1]
\n
"
)
"ld1
{
v2.8h
}
, [%[b_ptr]], #16
\n
"
IF_MN_GT
(
0
,
6
,
"fmla v14.8h, v5.8h, v4.h[2]
\n
"
)
IF_MN_GT
(
0
,
7
,
"fmla v15.8h, v5.8h, v4.h[3]
\n
"
)
IF_MN_GT
(
0
,
8
,
"fmla v16.8h, v5.8h, v4.h[4]
\n
"
)
IF_MN_GT
(
0
,
9
,
"fmla v17.8h, v5.8h, v4.h[5]
\n
"
)
IF_MN_GT
(
0
,
10
,
"fmla v18.8h, v5.8h, v4.h[6]
\n
"
)
IF_MN_GT
(
0
,
11
,
"fmla v19.8h, v5.8h, v4.h[7]
\n
"
)
IF_MN_GT
(
1
,
0
,
"fmla v20.8h, v6.8h, v3.h[4]
\n
"
)
IF_MN_GT
(
1
,
1
,
"fmla v21.8h, v6.8h, v3.h[5]
\n
"
)
IF_MN_GT
(
1
,
2
,
"fmla v22.8h, v6.8h, v3.h[6]
\n
"
)
IF_MN_GT
(
1
,
3
,
"fmla v23.8h, v6.8h, v3.h[7]
\n
"
)
IF_MN_GT
(
1
,
4
,
"fmla v24.8h, v6.8h, v4.h[0]
\n
"
)
IF_MN_GT
(
1
,
5
,
"fmla v25.8h, v6.8h, v4.h[1]
\n
"
)
"subs %w[K], %w[K], #1
\n
"
IF_MN_GT
(
1
,
6
,
"fmla v26.8h, v6.8h, v4.h[2]
\n
"
)
IF_MN_GT
(
1
,
7
,
"fmla v27.8h, v6.8h, v4.h[3]
\n
"
)
IF_MN_GT
(
1
,
8
,
"fmla v28.8h, v6.8h, v4.h[4]
\n
"
)
IF_MN_GT
(
1
,
9
,
"fmla v29.8h, v6.8h, v4.h[5]
\n
"
)
IF_MN_GT
(
1
,
10
,
"fmla v30.8h, v6.8h, v4.h[6]
\n
"
)
IF_MN_GT
(
1
,
11
,
"fmla v31.8h, v6.8h, v4.h[7]
\n
"
)
"bne 3b
\n
"
"4:
\n
"
"cmp %w[oddK], #1
\n
"
"beq 5f
\n
"
// even tail
"ld1
{
v3.8h
}
, [%[b_ptr]], #16
\n
"
IF_MN_GT
(
0
,
0
,
"fmla v8.8h, v0.8h, v2.h[0]
\n
"
)
IF_MN_GT
(
0
,
1
,
"fmla v9.8h, v0.8h, v2.h[1]
\n
"
)
IF_MN_GT
(
0
,
2
,
"fmla v10.8h, v0.8h, v2.h[2]
\n
"
)
IF_MN_GT
(
0
,
3
,
"fmla v11.8h, v0.8h, v2.h[3]
\n
"
)
IF_M_GT
(
1
,
"ld1
{
v1.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
4
,
"fmla v12.8h, v0.8h, v2.h[4]
\n
"
)
IF_MN_GT
(
0
,
5
,
"fmla v13.8h, v0.8h, v2.h[5]
\n
"
)
IF_MN_GT
(
0
,
6
,
"fmla v14.8h, v0.8h, v2.h[6]
\n
"
)
IF_MN_GT
(
0
,
7
,
"fmla v15.8h, v0.8h, v2.h[7]
\n
"
)
IF_M_GT
(
0
,
"ld1
{
v5.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
8
,
"fmla v16.8h, v0.8h, v3.h[0]
\n
"
)
IF_MN_GT
(
0
,
9
,
"fmla v17.8h, v0.8h, v3.h[1]
\n
"
)
IF_MN_GT
(
0
,
10
,
"fmla v18.8h, v0.8h, v3.h[2]
\n
"
)
IF_MN_GT
(
0
,
11
,
"fmla v19.8h, v0.8h, v3.h[3]
\n
"
)
IF_MN_GT
(
1
,
0
,
"fmla v20.8h, v1.8h, v2.h[0]
\n
"
)
IF_MN_GT
(
1
,
1
,
"fmla v21.8h, v1.8h, v2.h[1]
\n
"
)
IF_MN_GT
(
1
,
2
,
"fmla v22.8h, v1.8h, v2.h[2]
\n
"
)
IF_MN_GT
(
1
,
3
,
"fmla v23.8h, v1.8h, v2.h[3]
\n
"
)
IF_MN_GT
(
1
,
4
,
"fmla v24.8h, v1.8h, v2.h[4]
\n
"
)
IF_MN_GT
(
1
,
5
,
"fmla v25.8h, v1.8h, v2.h[5]
\n
"
)
"ld1
{
v4.8h
}
, [%[b_ptr]], #16
\n
"
IF_MN_GT
(
1
,
6
,
"fmla v26.8h, v1.8h, v2.h[6]
\n
"
)
IF_MN_GT
(
1
,
7
,
"fmla v27.8h, v1.8h, v2.h[7]
\n
"
)
IF_MN_GT
(
1
,
8
,
"fmla v28.8h, v1.8h, v3.h[0]
\n
"
)
IF_MN_GT
(
1
,
9
,
"fmla v29.8h, v1.8h, v3.h[1]
\n
"
)
IF_MN_GT
(
1
,
10
,
"fmla v30.8h, v1.8h, v3.h[2]
\n
"
)
IF_MN_GT
(
1
,
11
,
"fmla v31.8h, v1.8h, v3.h[3]
\n
"
)
IF_M_GT
(
1
,
"ld1
{
v6.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
0
,
"fmla v8.8h, v5.8h, v3.h[4]
\n
"
)
IF_MN_GT
(
0
,
1
,
"fmla v9.8h, v5.8h, v3.h[5]
\n
"
)
IF_MN_GT
(
0
,
2
,
"fmla v10.8h, v5.8h, v3.h[6]
\n
"
)
IF_MN_GT
(
0
,
3
,
"fmla v11.8h, v5.8h, v3.h[7]
\n
"
)
IF_MN_GT
(
0
,
4
,
"fmla v12.8h, v5.8h, v4.h[0]
\n
"
)
IF_MN_GT
(
0
,
5
,
"fmla v13.8h, v5.8h, v4.h[1]
\n
"
)
IF_MN_GT
(
0
,
6
,
"fmla v14.8h, v5.8h, v4.h[2]
\n
"
)
IF_MN_GT
(
0
,
7
,
"fmla v15.8h, v5.8h, v4.h[3]
\n
"
)
IF_MN_GT
(
0
,
0
,
"st1
{
v8.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
8
,
"fmla v16.8h, v5.8h, v4.h[4]
\n
"
)
IF_MN_GT
(
0
,
1
,
"st1
{
v9.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
9
,
"fmla v17.8h, v5.8h, v4.h[5]
\n
"
)
IF_MN_GT
(
0
,
2
,
"st1
{
v10.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
10
,
"fmla v18.8h, v5.8h, v4.h[6]
\n
"
)
IF_MN_GT
(
0
,
3
,
"st1
{
v11.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
11
,
"fmla v19.8h, v5.8h, v4.h[7]
\n
"
)
IF_MN_GT
(
0
,
4
,
"st1
{
v12.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
0
,
"fmla v20.8h, v6.8h, v3.h[4]
\n
"
)
IF_MN_GT
(
0
,
5
,
"st1
{
v13.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
1
,
"fmla v21.8h, v6.8h, v3.h[5]
\n
"
)
IF_MN_GT
(
0
,
6
,
"st1
{
v14.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
2
,
"fmla v22.8h, v6.8h, v3.h[6]
\n
"
)
IF_MN_GT
(
0
,
7
,
"st1
{
v15.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
3
,
"fmla v23.8h, v6.8h, v3.h[7]
\n
"
)
IF_MN_GT
(
0
,
8
,
"st1
{
v16.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
4
,
"fmla v24.8h, v6.8h, v4.h[0]
\n
"
)
IF_MN_GT
(
0
,
9
,
"st1
{
v17.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
5
,
"fmla v25.8h, v6.8h, v4.h[1]
\n
"
)
IF_MN_GT
(
0
,
10
,
"st1
{
v18.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
6
,
"fmla v26.8h, v6.8h, v4.h[2]
\n
"
)
IF_MN_GT
(
0
,
11
,
"st1
{
v19.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
7
,
"fmla v27.8h, v6.8h, v4.h[3]
\n
"
)
IF_MN_GT
(
1
,
0
,
"st1
{
v20.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
8
,
"fmla v28.8h, v6.8h, v4.h[4]
\n
"
)
IF_MN_GT
(
1
,
1
,
"st1
{
v21.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
9
,
"fmla v29.8h, v6.8h, v4.h[5]
\n
"
)
IF_MN_GT
(
1
,
2
,
"st1
{
v22.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
10
,
"fmla v30.8h, v6.8h, v4.h[6]
\n
"
)
IF_MN_GT
(
1
,
3
,
"st1
{
v23.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
11
,
"fmla v31.8h, v6.8h, v4.h[7]
\n
"
)
IF_MN_GT
(
1
,
4
,
"st1
{
v24.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
5
,
"st1
{
v25.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
6
,
"st1
{
v26.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
7
,
"st1
{
v27.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
8
,
"st1
{
v28.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
9
,
"st1
{
v29.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
10
,
"st1
{
v30.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
11
,
"st1
{
v31.8h
}
, [%[outptr1]], #16
\n
"
)
"b 6f
\n
"
"5:
\n
"
// odd tail
"ld1
{
v3.4h
}
, [%[b_ptr]], #8
\n
"
IF_MN_GT
(
0
,
0
,
"fmla v8.8h, v0.8h, v2.h[0]
\n
"
)
IF_MN_GT
(
0
,
1
,
"fmla v9.8h, v0.8h, v2.h[1]
\n
"
)
IF_MN_GT
(
0
,
2
,
"fmla v10.8h, v0.8h, v2.h[2]
\n
"
)
IF_MN_GT
(
0
,
3
,
"fmla v11.8h, v0.8h, v2.h[3]
\n
"
)
IF_MN_GT
(
0
,
4
,
"fmla v12.8h, v0.8h, v2.h[4]
\n
"
)
IF_MN_GT
(
0
,
5
,
"fmla v13.8h, v0.8h, v2.h[5]
\n
"
)
IF_M_GT
(
1
,
"ld1
{
v1.8h
}
, [%[a_ptr]], #16
\n
"
)
IF_MN_GT
(
0
,
6
,
"fmla v14.8h, v0.8h, v2.h[6]
\n
"
)
IF_MN_GT
(
0
,
7
,
"fmla v15.8h, v0.8h, v2.h[7]
\n
"
)
IF_MN_GT
(
0
,
0
,
"st1
{
v8.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
8
,
"fmla v16.8h, v0.8h, v3.h[0]
\n
"
)
IF_MN_GT
(
0
,
1
,
"st1
{
v9.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
9
,
"fmla v17.8h, v0.8h, v3.h[1]
\n
"
)
IF_MN_GT
(
0
,
2
,
"st1
{
v10.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
10
,
"fmla v18.8h, v0.8h, v3.h[2]
\n
"
)
IF_MN_GT
(
0
,
3
,
"st1
{
v11.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
0
,
11
,
"fmla v19.8h, v0.8h, v3.h[3]
\n
"
)
IF_MN_GT
(
0
,
4
,
"st1
{
v12.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
0
,
"fmla v20.8h, v1.8h, v2.h[0]
\n
"
)
IF_MN_GT
(
0
,
5
,
"st1
{
v13.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
1
,
"fmla v21.8h, v1.8h, v2.h[1]
\n
"
)
IF_MN_GT
(
0
,
6
,
"st1
{
v14.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
2
,
"fmla v22.8h, v1.8h, v2.h[2]
\n
"
)
IF_MN_GT
(
0
,
7
,
"st1
{
v15.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
3
,
"fmla v23.8h, v1.8h, v2.h[3]
\n
"
)
IF_MN_GT
(
0
,
8
,
"st1
{
v16.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
4
,
"fmla v24.8h, v1.8h, v2.h[4]
\n
"
)
IF_MN_GT
(
0
,
9
,
"st1
{
v17.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
5
,
"fmla v25.8h, v1.8h, v2.h[5]
\n
"
)
IF_MN_GT
(
0
,
10
,
"st1
{
v18.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
6
,
"fmla v26.8h, v1.8h, v2.h[6]
\n
"
)
IF_MN_GT
(
0
,
11
,
"st1
{
v19.8h
}
, [%[outptr0]], #16
\n
"
)
IF_MN_GT
(
1
,
7
,
"fmla v27.8h, v1.8h, v2.h[7]
\n
"
)
IF_MN_GT
(
1
,
0
,
"st1
{
v20.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
8
,
"fmla v28.8h, v1.8h, v3.h[0]
\n
"
)
IF_MN_GT
(
1
,
1
,
"st1
{
v21.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
9
,
"fmla v29.8h, v1.8h, v3.h[1]
\n
"
)
IF_MN_GT
(
1
,
2
,
"st1
{
v22.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
10
,
"fmla v30.8h, v1.8h, v3.h[2]
\n
"
)
IF_MN_GT
(
1
,
3
,
"st1
{
v23.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
11
,
"fmla v31.8h, v1.8h, v3.h[3]
\n
"
)
IF_MN_GT
(
1
,
4
,
"st1
{
v24.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
5
,
"st1
{
v25.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
6
,
"st1
{
v26.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
7
,
"st1
{
v27.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
8
,
"st1
{
v28.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
9
,
"st1
{
v29.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
10
,
"st1
{
v30.8h
}
, [%[outptr1]], #16
\n
"
)
IF_MN_GT
(
1
,
11
,
"st1
{
v31.8h
}
, [%[outptr1]], #16
\n
"
)
"6:
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
oddK
]
"+r"
(
oddK
),
[
outptr0
]
"+r"
(
outptr0
),
[
outptr1
]
"+r"
(
outptr1
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
,
"x1"
,
"x2"
,
"cc"
,
"memory"
);
#undef IF_MN_GT
#undef IF_N_GT
#undef IF_M_GT
}
#endif
\ No newline at end of file
dnn/src/aarch64/matrix_mul/fp16/strategy.h
浏览文件 @
8fe8edf4
...
...
@@ -9,6 +9,9 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY
(
dt_float16
,
dt_float16
,
dt_float16
,
8
,
24
,
1
,
false
,
true
,
hgemm_8x24
);
MEGDNN_REG_GEMM_STRATEGY
(
dt_float16
,
dt_float16
,
dt_float16
,
16
,
12
,
1
,
false
,
false
,
hgemm_mk8_16x12
);
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
dt_float16
,
dt_float16
,
dt_float16
,
8
,
8
,
1
,
false
,
true
,
gemm_nopack_f16_8x8
);
...
...
dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_16x12.cpp
0 → 100644
浏览文件 @
8fe8edf4
#include "src/aarch64/matrix_mul/fp16/kernel_mk8_16x12.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using
namespace
megdnn
;
using
namespace
aarch64
;
using
namespace
aarch64
::
matmul
;
typedef
void
(
*
kern_func
)(
const
dt_float16
*
,
const
dt_float16
*
,
int
,
dt_float16
*
,
int
,
bool
);
static
kern_func
kern_func_table
[
2
][
12
]
=
{
{
matmul_mk8_16x12
::
kern
<
1
,
1
>
,
matmul_mk8_16x12
::
kern
<
1
,
2
>
,
matmul_mk8_16x12
::
kern
<
1
,
3
>
,
matmul_mk8_16x12
::
kern
<
1
,
4
>
,
matmul_mk8_16x12
::
kern
<
1
,
5
>
,
matmul_mk8_16x12
::
kern
<
1
,
6
>
,
matmul_mk8_16x12
::
kern
<
1
,
7
>
,
matmul_mk8_16x12
::
kern
<
1
,
8
>
,
matmul_mk8_16x12
::
kern
<
1
,
9
>
,
matmul_mk8_16x12
::
kern
<
1
,
10
>
,
matmul_mk8_16x12
::
kern
<
1
,
11
>
,
matmul_mk8_16x12
::
kern
<
1
,
12
>
},
{
matmul_mk8_16x12
::
kern
<
2
,
1
>
,
matmul_mk8_16x12
::
kern
<
2
,
2
>
,
matmul_mk8_16x12
::
kern
<
2
,
3
>
,
matmul_mk8_16x12
::
kern
<
2
,
4
>
,
matmul_mk8_16x12
::
kern
<
2
,
5
>
,
matmul_mk8_16x12
::
kern
<
2
,
6
>
,
matmul_mk8_16x12
::
kern
<
2
,
7
>
,
matmul_mk8_16x12
::
kern
<
2
,
8
>
,
matmul_mk8_16x12
::
kern
<
2
,
9
>
,
matmul_mk8_16x12
::
kern
<
2
,
10
>
,
matmul_mk8_16x12
::
kern
<
2
,
11
>
,
matmul_mk8_16x12
::
kern
<
2
,
12
>
}};
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
hgemm_mk8_16x12
);
void
hgemm_mk8_16x12
::
pack_A
(
dt_float16
*
out
,
const
dt_float16
*
in
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose_A
)
const
{
megdnn_assert
(
!
transpose_A
,
"mk8 float16 matmul not support transpose A"
);
matmul_mk8_16x12
::
hgemm_16x12_pack_A
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
);
}
void
hgemm_mk8_16x12
::
pack_B
(
dt_float16
*
out
,
const
dt_float16
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose_B
)
const
{
megdnn_assert
(
!
transpose_B
,
"mk8 float16 matmul not support transpose B"
);
matmul_mk8_16x12
::
hgemm_16x12_pack_B
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
);
}
// Overview of register layout:
//
// A 12x2 cell of Rhs is stored in 16bit in q2-q4.
// A 2x16 cell of Lhs is stored in 16bit in q0-q1 and q5-q6
// A 12x16 block of accumulators is stored in 16bit in q8--q31.
//
// +----+----+
// | v0 | v1 |
// Rhs +----+----+
// | v5 | v6 |
// +----+----+
//
// | | |
//
// Lhs | | |
//
// +---------------+---------------+ - - - - +----+----+
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v8 | v20|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v9 | v21|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v10| v22|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v11| v23|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v12| v24|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v13| v25|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v14| v26|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v15| v27|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v16| v28|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v17| v29|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v18| v30|
// |v2[0-7] v3[0-3]|v3[4-7] v4[0-7]| | v19| v31|
// +---------------+---------------+ - - - - +----+----+
//
// Accumulator
void
hgemm_mk8_16x12
::
kern
(
const
dt_float16
*
packedA
,
const
dt_float16
*
packedB
,
size_t
M
,
size_t
N
,
size_t
K
,
dt_float16
*
C
,
size_t
LDC
,
bool
is_first_k
,
const
dt_float16
*
,
dt_float16
*
)
const
{
megdnn_assert
(
A_dtype
.
enumv
()
==
B_dtype
.
enumv
()
&&
A_dtype
.
enumv
()
==
C_dtype
.
enumv
()
&&
A_dtype
.
enumv
()
==
DTypeEnum
::
Float16
);
const
size_t
K16
=
K
*
16
;
const
size_t
K8
=
K
*
8
;
const
size_t
K12
=
K
*
12
;
constexpr
size_t
PACK_C_SIZE
=
8
;
constexpr
size_t
A_BLOCK
=
16
;
constexpr
size_t
B_BLOCK
=
12
;
size_t
m
=
0
;
for
(;
m
<
M
;
m
+=
A_BLOCK
)
{
dt_float16
*
outptr
=
C
+
(
m
/
PACK_C_SIZE
*
LDC
);
const
size_t
m_func_idx
=
std
::
min
<
size_t
>
(
M
-
m
,
A_BLOCK
)
/
8
-
1
;
size_t
n
=
0
;
const
dt_float16
*
cur_packedB
=
packedB
;
for
(;
n
<
N
;
n
+=
B_BLOCK
)
{
const
size_t
n_func_idx
=
std
::
min
<
size_t
>
(
N
-
n
,
B_BLOCK
)
-
1
;
kern_func_table
[
m_func_idx
][
n_func_idx
](
packedA
,
cur_packedB
,
K
,
outptr
,
LDC
,
is_first_k
);
cur_packedB
+=
K12
;
outptr
+=
B_BLOCK
*
PACK_C_SIZE
;
}
packedA
+=
(
m_func_idx
?
K16
:
K8
);
}
}
#endif
\ No newline at end of file
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
8fe8edf4
...
...
@@ -15,6 +15,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K8x24x1
f16_k8x24x1
;
AlgoF16MK8_8x8
f16_mk8_8x8
;
AlgoF16MK8_16x12x1
f16_mk8_16x12x1
;
#endif
#if MGB_ENABLE_DOT
AlgoInt8x8x32K8x12x4DotProd
int8x8x32_k8x12x4_dotprod
;
...
...
@@ -52,6 +53,7 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
m_all_algos
.
emplace_back
(
&
f16_k8x24x1
);
m_all_algos
.
emplace_back
(
&
f16_mk8_8x8
);
m_all_algos
.
emplace_back
(
&
f16_mk8_16x12x1
);
#endif
#if MGB_ENABLE_DOT
m_all_algos
.
emplace_back
(
&
int8x8x32_k8x12x4_dotprod
);
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
8fe8edf4
...
...
@@ -27,6 +27,7 @@ private:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class
AlgoF16K8x24x1
;
// Aarch64 F16 Kernel 8x24x1
class
AlgoF16MK8_8x8
;
// Aarch64 F16 Format MK8 block 16x8
class
AlgoF16MK8_16x12x1
;
// Aarch64 F16 Format MK8 block 16x12x1
#endif
#if MGB_ENABLE_DOT
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
8fe8edf4
...
...
@@ -136,6 +136,7 @@ public:
AARCH64_F32_GEMV
,
AARCH64_F16_K8X24X1
,
AARCH64_F16_MK8_8X8
,
AARCH64_F16_MK8_16X12X1
,
AARCH64_INT8X8X32_K8X12X4_DOTPROD
,
AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD
,
AARCH64_INT8X8X32_MK4_4X4X16
,
...
...
dnn/test/aarch64/matrix_mul.cpp
浏览文件 @
8fe8edf4
...
...
@@ -74,6 +74,12 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
handle
(),
"AARCH64_F16_MK8_8X8"
,
param
::
MatrixMul
::
Format
::
MK8
,
1
);
}
TEST_F
(
AARCH64
,
MATRIX_MUL_F16_MK8_16x12x1
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
handle
(),
"AARCH64_F16_MK8_16X12X1"
,
param
::
MatrixMul
::
Format
::
MK8
,
1
);
}
#endif
#if MGB_ENABLE_DOT
...
...
@@ -790,6 +796,14 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) {
"AARCH64_F16_MK8_8X8"
,
param
::
MatrixMul
::
Format
::
MK8
,
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
"AARCH64_F16_K8X24X1"
);
}
TEST_F
(
AARCH64
,
BENCHMARK_MATRIX_MUL_F16_MK8_16x12
)
{
auto
args
=
matrix_mul
::
get_benchmark_matmul_mk_packed_args
(
8
);
matrix_mul
::
benchmark_with_contrast
(
handle
(),
args
,
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
"AARCH64_F16_MK8_16X12X1"
,
param
::
MatrixMul
::
Format
::
MK8
,
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
"AARCH64_F16_K8X24X1"
);
}
#endif
TEST_F
(
AARCH64
,
BENCHMARK_MATRIX_MUL_INT16x16x32
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录