Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
92b12685
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
92b12685
编写于
9月 24, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/aarch64): add aarch64 int8X8X16_mk4_k8x8x8 matmul, performance is better
GitOrigin-RevId: b6af21e8e314b4edd62f0fddcf8578d2eaa0fc2a
上级
5ee1a1c4
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
1755 addition
and
0 deletion
+1755
-0
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+70
-0
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+16
-0
dnn/src/aarch64/matrix_mul/asm/common.h
dnn/src/aarch64/matrix_mul/asm/common.h
+56
-0
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
+1451
-0
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp
+78
-0
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h
+2
-0
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+2
-0
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+1
-0
dnn/test/aarch64/matrix_mul.cpp
dnn/test/aarch64/matrix_mul.cpp
+79
-0
未找到文件。
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
92b12685
...
...
@@ -1310,4 +1310,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
int32_t
);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
namespace {
/**
 * \brief Kernel entry point for the int8 x int8 -> int16 MK4 8x8x8 matmul.
 *
 * Fetches the A/B/C pointers from \p kern_param, builds a
 * gemm_s8x8x16_mk4_8x8x8 strategy from the problem sizes and data types, and
 * runs it through GemmInterleaved with kern_param.workspace_ptr as the
 * workspace.
 */
void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    using Strategy = aarch64::matmul::gemm_s8x8x16_mk4_8x8x8;
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) {
        // Operand pointers are fetched first, as in the original ordering.
        const auto lhs = kern_param.A<dt_int8>();
        const auto rhs = kern_param.B<dt_int8>();
        auto dst = kern_param.C<dt_int16>();

        Strategy strategy(kern_param.M, kern_param.N, kern_param.K,
                          kern_param.A_type, kern_param.B_type,
                          kern_param.C_type);
        megdnn::matmul::GemmInterleaved<Strategy> gemm(
                kern_param.M, kern_param.N, kern_param.K, kern_param.trA,
                kern_param.trB, strategy);
        gemm.execute(lhs, kern_param.LDA, rhs, kern_param.LDB, dst,
                     kern_param.LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
/**
 * \brief Whether this algorithm can handle the given problem.
 *
 * Requires an int8x8x16-compatible dtype combination, MK4 format, the DEFAULT
 * compute mode, non-transposed operands, and M and K both divisible by 4.
 */
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    // Checks are performed in the same order as the original && chain.
    if (!can_be_treated_as_int8x8x16(kern_size_param)) {
        return false;
    }
    if (!(kern_size_param.format == param::MatrixMul::Format::MK4)) {
        return false;
    }
    if (!(kern_size_param.compute_mode == Param::ComputeMode::DEFAULT)) {
        return false;
    }
    if (kern_size_param.trA || kern_size_param.trB) {
        return false;
    }
    if (kern_size_param.M % 4 != 0) {
        return false;
    }
    return kern_size_param.K % 4 == 0;
}
/**
 * \brief Preference hint; this algorithm reports itself as preferred for
 * every problem size.
 */
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred(
        const KernSizeParam& /* kern_size_param */) const {
    return true;
}
size_t
MatrixMulImpl
::
AlgoInt8x8x16MK4_K8x8x8
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_aarch64_matmul_kern
,
midout_iv
(
"AlgoInt8x8x16_MK4_8x8x8::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
aarch64
::
matmul
::
gemm_s8x8x16_mk4_8x8x8
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
gemm_s8x8x16_mk4_8x8x8
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
return
0
;
}
/**
 * \brief Kernel to run for this algorithm; always the file-local
 * int8x8x16_mk4_8x8x8_kern, independent of the problem size.
 */
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
        const KernSizeParam& /* kern_size_param */) const {
    return int8x8x16_mk4_8x8x8_kern;
}
// Expands the project macro for AlgoInt8x8x16MK4_K8x8x8 with the
// gemm_s8x8x16_mk4_8x8x8 strategy, int8_t as the first type argument and
// int16_t as the second. Presumably this defines the members declared by
// MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() in the class (the im2col-facing GEMM
// entry points) -- confirm against the macro definition.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoInt8x8x16MK4_K8x8x8
,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x16MK4_K8x8x8Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8x8x16_mk4_8x8x8
,
int8_t
,
int16_t
);
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
92b12685
...
...
@@ -202,6 +202,22 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
/**
 * \brief int8 x int8 -> int16 matmul algorithm for the MK4 format, backed by
 * the 8x8x8 kernel (see usable() for the exact acceptance conditions).
 */
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
    // ---- identity / static properties ----
    const char* name() const override {
        return "AARCH64_INT8X8X16_MK4_K8X8X8";
    }
    bool is_reproducible() const override { return true; }
    void* type() const override { return sm_arm_common_algo_type; }
    PackMode packmode() const override { return PackMode::DEFAULT; }

    // ---- problem-dependent queries (defined in algos.cpp) ----
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;

    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class
MatrixMulImpl
::
AlgoInt8x8x16MK4_4x4x8
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/aarch64/matrix_mul/asm/common.h
浏览文件 @
92b12685
...
...
@@ -2101,6 +2101,62 @@ static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) {
vreinterpretq_s32_s8
(
input2
),
3
);
}
template
<
typename
T
>
static
inline
void
interleave_8x8_mk4_b
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
T
*&
outptr
)
{
static_assert
(
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
,
"transpose_8x4_1_b only support uint8_t and int8_t"
);
asm
volatile
(
"ld1 {v0.4s}, [%[inptr0]], #16
\n
"
"ld1 {v1.4s}, [%[inptr1]], #16
\n
"
"ld1 {v2.4s}, [%[inptr0]], #16
\n
"
"ld1 {v3.4s}, [%[inptr1]], #16
\n
"
"zip1 v4.4s, v0.4s, v1.4s
\n
"
"zip2 v5.4s, v0.4s, v1.4s
\n
"
"zip1 v6.4s, v2.4s, v3.4s
\n
"
"zip2 v7.4s, v2.4s, v3.4s
\n
"
"st1 {v4.4s},[%[outptr]],#16
\n
"
"st1 {v5.4s},[%[outptr]],#16
\n
"
"st1 {v6.4s},[%[outptr]],#16
\n
"
"st1 {v7.4s},[%[outptr]],#16
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
inptr1
]
"+r"
(
inptr1
),
[
outptr
]
"+r"
(
outptr
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
transpose_8x8_mk4_b
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
T
*
outptr
)
{
static_assert
(
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
,
"transpose_8x4_1_b only support uint8_t and int8_t"
);
asm
volatile
(
"ld4 {v0.8b-v3.8b}, [%[inptr0]], #32
\n
"
"ld4 {v4.8b-v7.8b}, [%[inptr1]], #32
\n
"
"st1 {v0.2s},[%[outptr]],#8
\n
"
"st1 {v1.2s},[%[outptr]],#8
\n
"
"st1 {v2.2s},[%[outptr]],#8
\n
"
"st1 {v3.2s},[%[outptr]],#8
\n
"
"st1 {v4.2s},[%[outptr]],#8
\n
"
"st1 {v5.2s},[%[outptr]],#8
\n
"
"st1 {v6.2s},[%[outptr]],#8
\n
"
"st1 {v7.2s},[%[outptr]],#8
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
inptr1
]
"+r"
(
inptr1
),
[
outptr
]
"+r"
(
outptr
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"memory"
);
}
}
// namespace aarch64
}
// namespace megdnn
...
...
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
0 → 100644
浏览文件 @
92b12685
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <inttypes.h>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul_mk4_8x8x8
{
/**
* Overview of register layout:
*
* A 8x8 cell of Lhs is stored in 8bit in v16-v17
* B 8x8 cell of Rhs is stored in 8bit in v0-v15, v20-v23
* C 8x8 block of accumulators is stored in 16bit in v24-v31
*
* +---------------------------------+
* | v0 ------------------------ v7 |
* | v8 ------------------------ v15|
* Rhs +---------------------------------+
* Lhs | |
* +--------+ - - - - +---------------------------------+
* | v16 | | v24 |
* | v17 | | v25 |
* | v16 | | v26 |
* | v17 | | v27 |
* | v16 | | v28 |
* | v17 | | v29 |
* | v16 | | v30 |
* | v17 | | v31 |
* +--------+ - - - - +---------------------------------+
*
* Accumulator
*/
static
void
kern_8x8
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
K
/=
8
;
LDC
=
LDC
*
sizeof
(
int16_t
);
const
int8_t
*
a_ptr
=
packB
;
//packA;
const
int8_t
*
b_ptr
=
packA
;
//packB;
// clang-format off
#define LOAD_C_8 \
"ld1 {v0.8h}, [x0], #16\n" \
"ld1 {v1.8h}, [x0], #16\n" \
"ld1 {v2.8h}, [x0], #16\n" \
"ld1 {v3.8h}, [x0], #16\n" \
"ld1 {v4.8h}, [x1], #16\n" \
"ld1 {v5.8h}, [x1], #16\n" \
"ld1 {v6.8h}, [x1], #16\n" \
"ld1 {v7.8h}, [x1], #16\n" \
#define STORE_C_8 \
"st1 {v0.8h}, [x0], #16\n" \
"st1 {v1.8h}, [x0], #16\n" \
"st1 {v2.8h}, [x0], #16\n" \
"st1 {v3.8h}, [x0], #16\n" \
"st1 {v4.8h}, [x1], #16\n" \
"st1 {v5.8h}, [x1], #16\n" \
"st1 {v6.8h}, [x1], #16\n" \
"st1 {v7.8h}, [x1], #16\n" \
register
int16_t
*
outptr
asm
(
"x0"
)
=
output
;
asm
volatile
(
"add x1, x0, %x[LDC]
\n
"
"eor v24.16b, v24.16b, v24.16b
\n
"
"PRFM PLDL1KEEP, [%[a_ptr], #512]
\n
"
"eor v25.16b, v25.16b, v25.16b
\n
"
"PRFM PLDL1KEEP, [%[b_ptr], #512]
\n
"
"eor v26.16b, v26.16b, v26.16b
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"eor v27.16b, v27.16b, v27.16b
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"eor v28.16b, v28.16b, v28.16b
\n
"
"eor v29.16b, v29.16b, v29.16b
\n
"
"eor v30.16b, v30.16b, v30.16b
\n
"
"eor v31.16b, v31.16b, v31.16b
\n
"
// General loop.
"1:
\n
"
"dup v0.8b,v20.b[0]
\n
"
"ld1 {v22.16b}, [%[a_ptr]],#16
\n
"
"dup v1.8b,v20.b[1]
\n
"
"ld1 {v23.16b}, [%[a_ptr]],#16
\n
"
"dup v2.8b,v20.b[2]
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v3.8b,v20.b[3]
\n
"
"dup v4.8b,v20.b[4]
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v5.8b,v20.b[5]
\n
"
"dup v6.8b,v20.b[6]
\n
"
"dup v7.8b,v20.b[7]
\n
"
"dup v8.8b,v20.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v20.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v20.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v20.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v20.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v20.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v20.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v20.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v21.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v21.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v21.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v21.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v21.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v21.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v21.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v21.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v21.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v21.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v21.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v21.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v21.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v21.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v21.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v21.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v22.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v22.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v22.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v22.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v22.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v22.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v22.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v22.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v22.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v22.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v22.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v22.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v22.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v22.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v22.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v22.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v23.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v23.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v23.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v23.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v23.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v23.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v23.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v23.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v23.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v23.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v23.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v23.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v23.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v23.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v23.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v23.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"subs %w[K], %w[K], #1
\n
"
"cbnz %w[K], 1b
\n
"
"cmp %w[is_first_k], #1
\n
"
"beq 2f
\n
"
LOAD_C_8
"b 3f
\n
"
"2:
\n
"
"eor v0.16b, v0.16b, v0.16b
\n
"
"eor v1.16b, v1.16b, v1.16b
\n
"
"eor v2.16b, v2.16b, v2.16b
\n
"
"eor v3.16b, v3.16b, v3.16b
\n
"
"eor v4.16b, v4.16b, v4.16b
\n
"
"eor v5.16b, v5.16b, v5.16b
\n
"
"eor v6.16b, v6.16b, v6.16b
\n
"
"eor v7.16b, v7.16b, v7.16b
\n
"
"3:
\n
"
"zip1 v8.2d, v24.2d, v25.2d
\n
"
"zip2 v9.2d, v24.2d, v25.2d
\n
"
"zip1 v10.2d, v26.2d, v27.2d
\n
"
"zip2 v11.2d, v26.2d, v27.2d
\n
"
"zip1 v12.2d, v28.2d, v29.2d
\n
"
"zip2 v13.2d, v28.2d, v29.2d
\n
"
"zip1 v14.2d, v30.2d, v31.2d
\n
"
"zip2 v15.2d, v30.2d, v31.2d
\n
"
"add v0.8h, v0.8h, v8.8h
\n
"
"add v1.8h, v1.8h, v10.8h
\n
"
"add v2.8h, v2.8h, v12.8h
\n
"
"add v3.8h, v3.8h, v14.8h
\n
"
"add v4.8h, v4.8h, v9.8h
\n
"
"add v5.8h, v5.8h, v11.8h
\n
"
"add v6.8h, v6.8h, v13.8h
\n
"
"add v7.8h, v7.8h, v15.8h
\n
"
// Store back into memory
STORE_C_8
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
K
]
"+r"
(
K
),
[
LDC
]
"+r"
(
LDC
),
[
outptr
]
"+r"
(
outptr
),
[
m_remain
]
"+r"
(
m_remain
),
[
n_remain
]
"+r"
(
n_remain
)
:
:
"cc"
,
"memory"
,
"x1"
,
"x2"
,
"x3"
,
"x4"
,
"x5"
,
"x6"
,
"x7"
,
"x8"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
// clang-format on
}
static
void
kern_8x8_remain
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
K
/=
8
;
LDC
=
LDC
*
sizeof
(
int16_t
);
const
int8_t
*
a_ptr
=
packB
;
const
int8_t
*
b_ptr
=
packA
;
// clang-format off
register
int16_t
*
outptr
asm
(
"x0"
)
=
output
;
asm
volatile
(
"add x1, x0, %x[LDC]
\n
"
"eor v24.16b, v24.16b, v24.16b
\n
"
"eor v25.16b, v25.16b, v25.16b
\n
"
"eor v26.16b, v26.16b, v26.16b
\n
"
"eor v27.16b, v27.16b, v27.16b
\n
"
"eor v28.16b, v28.16b, v28.16b
\n
"
"eor v29.16b, v29.16b, v29.16b
\n
"
"eor v30.16b, v30.16b, v30.16b
\n
"
"eor v31.16b, v31.16b, v31.16b
\n
"
// General loop.
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"PRFM PLDL1KEEP, [%[a_ptr], #512]
\n
"
"PRFM PLDL1KEEP, [%[b_ptr], #512]
\n
"
"1:
\n
"
"dup v0.8b,v20.b[0]
\n
"
"ld1 {v22.16b}, [%[a_ptr]],#16
\n
"
"dup v1.8b,v20.b[1]
\n
"
"ld1 {v23.16b}, [%[a_ptr]],#16
\n
"
"dup v2.8b,v20.b[2]
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v3.8b,v20.b[3]
\n
"
"dup v4.8b,v20.b[4]
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v5.8b,v20.b[5]
\n
"
"dup v6.8b,v20.b[6]
\n
"
"dup v7.8b,v20.b[7]
\n
"
"dup v8.8b,v20.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v20.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v20.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v20.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v20.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v20.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v20.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v20.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v21.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v21.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v21.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v21.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v21.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v21.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v21.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v21.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v21.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v21.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v21.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v21.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v21.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v21.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v21.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v21.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v22.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v22.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v22.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v22.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v22.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v22.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v22.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v22.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v22.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v22.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v22.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v22.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v22.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v22.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v22.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v22.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v23.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v23.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v23.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v23.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v23.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v23.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v23.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v23.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v23.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v23.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v23.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v23.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v23.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v23.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v23.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v23.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"subs %w[K], %w[K], #1
\n
"
"cbnz %w[K], 1b
\n
"
"cmp %w[is_first_k], #1
\n
"
"beq 2f
\n
"
"cmp %x[m_remain], #8
\n
"
"beq 8f
\n
"
"cmp %x[m_remain], #4
\n
"
"beq 9f
\n
"
"8:
\n
"
"cmp %x[n_remain], #8
\n
"
"beq 200f
\n
"
"cmp %x[n_remain], #7
\n
"
"beq 201f
\n
"
"cmp %x[n_remain], #6
\n
"
"beq 202f
\n
"
"cmp %x[n_remain], #5
\n
"
"beq 203f
\n
"
"cmp %x[n_remain], #4
\n
"
"beq 204f
\n
"
"cmp %x[n_remain], #3
\n
"
"beq 205f
\n
"
"cmp %x[n_remain], #2
\n
"
"beq 206f
\n
"
"cmp %x[n_remain], #1
\n
"
"beq 207f
\n
"
"200:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.8h}, [x0], #16
\n
"
"ld1 {v3.8h}, [x0], #16
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"ld1 {v5.8h}, [x1], #16
\n
"
"ld1 {v6.8h}, [x1], #16
\n
"
"ld1 {v7.8h}, [x1], #16
\n
"
"b 3f
\n
"
"201:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.8h}, [x0], #16
\n
"
"ld1 {v3.d}[0], [x0], #8
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"ld1 {v5.8h}, [x1], #16
\n
"
"ld1 {v6.8h}, [x1], #16
\n
"
"ld1 {v7.d}[0], [x1], #8
\n
"
"b 3f
\n
"
"202:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.8h}, [x0], #16
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"ld1 {v5.8h}, [x1], #16
\n
"
"ld1 {v6.8h}, [x1], #16
\n
"
"b 3f
\n
"
"203:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.d}[0], [x0], #8
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"ld1 {v5.8h}, [x1], #16
\n
"
"ld1 {v6.d}[0], [x1], #8
\n
"
"b 3f
\n
"
"204:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"ld1 {v5.8h}, [x1], #16
\n
"
"b 3f
\n
"
"205:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.d}[0], [x0], #8
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"ld1 {v5.d}[0], [x1], #8
\n
"
"b 3f
\n
"
"206:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v4.8h}, [x1], #16
\n
"
"b 3f
\n
"
"207:
\n
"
"ld1 {v0.d}[0], [x0], #8
\n
"
"ld1 {v4.d}[0], [x1], #8
\n
"
"b 3f
\n
"
"9:
\n
"
"cmp %x[n_remain], #8
\n
"
"beq 300f
\n
"
"cmp %x[n_remain], #7
\n
"
"beq 301f
\n
"
"cmp %x[n_remain], #6
\n
"
"beq 302f
\n
"
"cmp %x[n_remain], #5
\n
"
"beq 303f
\n
"
"cmp %x[n_remain], #4
\n
"
"beq 304f
\n
"
"cmp %x[n_remain], #3
\n
"
"beq 305f
\n
"
"cmp %x[n_remain], #2
\n
"
"beq 306f
\n
"
"cmp %x[n_remain], #1
\n
"
"beq 307f
\n
"
"300:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.8h}, [x0], #16
\n
"
"ld1 {v3.8h}, [x0], #16
\n
"
"b 3f
\n
"
"301:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.8h}, [x0], #16
\n
"
"ld1 {v3.d}[0], [x0], #8
\n
"
"b 3f
\n
"
"302:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.8h}, [x0], #16
\n
"
"b 3f
\n
"
"303:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"ld1 {v2.d}[0], [x0], #8
\n
"
"b 3f
\n
"
"304:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.8h}, [x0], #16
\n
"
"b 3f
\n
"
"305:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"ld1 {v1.d}[0], [x0], #8
\n
"
"b 3f
\n
"
"306:
\n
"
"ld1 {v0.8h}, [x0], #16
\n
"
"b 3f
\n
"
"307:
\n
"
"ld1 {v0.d}[0], [x0], #8
\n
"
"b 3f
\n
"
"2:
\n
"
"eor v0.16b, v0.16b, v0.16b
\n
"
"eor v1.16b, v1.16b, v1.16b
\n
"
"eor v2.16b, v2.16b, v2.16b
\n
"
"eor v3.16b, v3.16b, v3.16b
\n
"
"eor v4.16b, v4.16b, v4.16b
\n
"
"eor v5.16b, v5.16b, v5.16b
\n
"
"eor v6.16b, v6.16b, v6.16b
\n
"
"eor v7.16b, v7.16b, v7.16b
\n
"
"3:
\n
"
"zip1 v8.2d, v24.2d, v25.2d
\n
"
"zip1 v10.2d, v26.2d, v27.2d
\n
"
"add v0.8h, v0.8h, v8.8h
\n
"
"zip1 v12.2d, v28.2d, v29.2d
\n
"
"add v1.8h, v1.8h, v10.8h
\n
"
"zip1 v14.2d, v30.2d, v31.2d
\n
"
"add v2.8h, v2.8h, v12.8h
\n
"
"add v3.8h, v3.8h, v14.8h
\n
"
"zip2 v9.2d, v24.2d, v25.2d
\n
"
"zip2 v11.2d, v26.2d, v27.2d
\n
"
"add v4.8h, v4.8h, v9.8h
\n
"
"zip2 v13.2d, v28.2d, v29.2d
\n
"
"add v5.8h, v5.8h, v11.8h
\n
"
"zip2 v15.2d, v30.2d, v31.2d
\n
"
"add v6.8h, v6.8h, v13.8h
\n
"
"add v7.8h, v7.8h, v15.8h
\n
"
//save to memory
"cmp %x[m_remain], #8
\n
"
"beq 4f
\n
"
"cmp %x[m_remain], #4
\n
"
"beq 5f
\n
"
"4:
\n
"
"cmp %x[n_remain], #8
\n
"
"beq 100f
\n
"
"cmp %x[n_remain], #7
\n
"
"beq 101f
\n
"
"cmp %x[n_remain], #6
\n
"
"beq 102f
\n
"
"cmp %x[n_remain], #5
\n
"
"beq 103f
\n
"
"cmp %x[n_remain], #4
\n
"
"beq 104f
\n
"
"cmp %x[n_remain], #3
\n
"
"beq 105f
\n
"
"cmp %x[n_remain], #2
\n
"
"beq 106f
\n
"
"cmp %x[n_remain], #1
\n
"
"beq 107f
\n
"
"100:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.8h}, [x0], #16
\n
"
"st1 {v3.8h}, [x0], #16
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"st1 {v5.8h}, [x1], #16
\n
"
"st1 {v6.8h}, [x1], #16
\n
"
"st1 {v7.8h}, [x1], #16
\n
"
"b 1000f
\n
"
"101:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.8h}, [x0], #16
\n
"
"st1 {v3.d}[0], [x0], #8
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"st1 {v5.8h}, [x1], #16
\n
"
"st1 {v6.8h}, [x1], #16
\n
"
"st1 {v7.d}[0], [x1], #8
\n
"
"b 1000f
\n
"
"102:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.8h}, [x0], #16
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"st1 {v5.8h}, [x1], #16
\n
"
"st1 {v6.8h}, [x1], #16
\n
"
"b 1000f
\n
"
"103:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.d}[0], [x0], #8
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"st1 {v5.8h}, [x1], #16
\n
"
"st1 {v6.d}[0], [x1], #8
\n
"
"b 1000f
\n
"
"104:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"st1 {v5.8h}, [x1], #16
\n
"
"b 1000f
\n
"
"105:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.d}[0], [x0], #8
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"st1 {v5.d}[0], [x1], #8
\n
"
"b 1000f
\n
"
"106:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v4.8h}, [x1], #16
\n
"
"b 1000f
\n
"
"107:
\n
"
"st1 {v0.d}[0], [x0], #8
\n
"
"st1 {v4.d}[0], [x1], #8
\n
"
"b 1000f
\n
"
"5:
\n
"
"cmp %x[n_remain], #8
\n
"
"beq 200f
\n
"
"cmp %x[n_remain], #7
\n
"
"beq 201f
\n
"
"cmp %x[n_remain], #6
\n
"
"beq 202f
\n
"
"cmp %x[n_remain], #5
\n
"
"beq 203f
\n
"
"cmp %x[n_remain], #4
\n
"
"beq 204f
\n
"
"cmp %x[n_remain], #3
\n
"
"beq 205f
\n
"
"cmp %x[n_remain], #2
\n
"
"beq 206f
\n
"
"cmp %x[n_remain], #1
\n
"
"beq 207f
\n
"
"200:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.8h}, [x0], #16
\n
"
"st1 {v3.8h}, [x0], #16
\n
"
"b 1000f
\n
"
"201:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.8h}, [x0], #16
\n
"
"st1 {v3.d}[0], [x0], #8
\n
"
"b 1000f
\n
"
"202:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.8h}, [x0], #16
\n
"
"b 1000f
\n
"
"203:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"st1 {v2.d}[0], [x0], #8
\n
"
"b 1000f
\n
"
"204:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.8h}, [x0], #16
\n
"
"b 1000f
\n
"
"205:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"st1 {v1.d}[0], [x0], #8
\n
"
"b 1000f
\n
"
"206:
\n
"
"st1 {v0.8h}, [x0], #16
\n
"
"b 1000f
\n
"
"207:
\n
"
"st1 {v0.d}[0], [x0], #8
\n
"
"b 1000f
\n
"
"1000:
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
K
]
"+r"
(
K
),
[
LDC
]
"+r"
(
LDC
),
[
outptr
]
"+r"
(
outptr
),
[
m_remain
]
"+r"
(
m_remain
),
[
n_remain
]
"+r"
(
n_remain
)
:
:
"cc"
,
"memory"
,
"x1"
,
"x2"
,
"x3"
,
"x4"
,
"x5"
,
"x6"
,
"x7"
,
"x8"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
// clang-format on
#undef LOAD_C_8
#undef STORE_C_8
}
static
void
kern_4x8
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
K
/=
8
;
LDC
=
LDC
*
sizeof
(
int16_t
);
const
int8_t
*
a_ptr
=
packB
;
//packA;
const
int8_t
*
b_ptr
=
packA
;
//packB;
// clang-format off
#define LOAD_C_4 \
"ld1 {v0.8h}, [x0], #16\n" \
"ld1 {v1.8h}, [x0], #16\n" \
"ld1 {v2.8h}, [x0], #16\n" \
"ld1 {v3.8h}, [x0], #16\n" \
#define STORE_C_4 \
"st1 {v0.8h}, [x0], #16\n" \
"st1 {v1.8h}, [x0], #16\n" \
"st1 {v2.8h}, [x0], #16\n" \
"st1 {v3.8h}, [x0], #16\n" \
register
int16_t
*
outptr
asm
(
"x0"
)
=
output
;
asm
volatile
(
"eor v24.16b, v24.16b, v24.16b
\n
"
"eor v25.16b, v25.16b, v25.16b
\n
"
"eor v26.16b, v26.16b, v26.16b
\n
"
"eor v27.16b, v27.16b, v27.16b
\n
"
"eor v28.16b, v28.16b, v28.16b
\n
"
"eor v29.16b, v29.16b, v29.16b
\n
"
"eor v30.16b, v30.16b, v30.16b
\n
"
"eor v31.16b, v31.16b, v31.16b
\n
"
// General loop.
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"PRFM PLDL1KEEP, [%[a_ptr], #512]
\n
"
"PRFM PLDL1KEEP, [%[b_ptr], #512]
\n
"
"1:
\n
"
"dup v0.8b,v20.b[0]
\n
"
"ld1 {v22.16b}, [%[a_ptr]],#16
\n
"
"dup v1.8b,v20.b[1]
\n
"
"ld1 {v23.16b}, [%[a_ptr]],#16
\n
"
"dup v2.8b,v20.b[2]
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v3.8b,v20.b[3]
\n
"
"dup v4.8b,v20.b[4]
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v5.8b,v20.b[5]
\n
"
"dup v6.8b,v20.b[6]
\n
"
"dup v7.8b,v20.b[7]
\n
"
"dup v8.8b,v20.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v20.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v20.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v20.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v20.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v20.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v20.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v20.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v21.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v21.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v21.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v21.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v21.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v21.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v21.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v21.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v21.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v21.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v21.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v21.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v21.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v21.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v21.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v21.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v22.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v22.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v22.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v22.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v22.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v22.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v22.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v22.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v22.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v22.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v22.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v22.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v22.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v22.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v22.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v22.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v23.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v23.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v23.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v23.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v23.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v23.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v23.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v23.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v23.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v23.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v23.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v23.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v23.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v23.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v23.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v23.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"subs %w[K], %w[K], #1
\n
"
"cbnz %w[K], 1b
\n
"
"cmp %w[is_first_k], #1
\n
"
"beq 2f
\n
"
LOAD_C_4
"b 3f
\n
"
"2:
\n
"
"eor v0.16b, v0.16b, v0.16b
\n
"
"eor v1.16b, v1.16b, v1.16b
\n
"
"eor v2.16b, v2.16b, v2.16b
\n
"
"eor v3.16b, v3.16b, v3.16b
\n
"
"eor v4.16b, v4.16b, v4.16b
\n
"
"eor v5.16b, v5.16b, v5.16b
\n
"
"eor v6.16b, v6.16b, v6.16b
\n
"
"eor v7.16b, v7.16b, v7.16b
\n
"
"3:
\n
"
"zip1 v8.2d, v24.2d, v25.2d
\n
"
"zip1 v10.2d, v26.2d, v27.2d
\n
"
"add v0.8h, v0.8h, v8.8h
\n
"
"zip1 v12.2d, v28.2d, v29.2d
\n
"
"add v1.8h, v1.8h, v10.8h
\n
"
"zip1 v14.2d, v30.2d, v31.2d
\n
"
"add v2.8h, v2.8h, v12.8h
\n
"
"add v3.8h, v3.8h, v14.8h
\n
"
// Store back into memory
STORE_C_4
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
K
]
"+r"
(
K
),
[
LDC
]
"+r"
(
LDC
),
[
outptr
]
"+r"
(
outptr
),
[
m_remain
]
"+r"
(
m_remain
),
[
n_remain
]
"+r"
(
n_remain
)
:
:
"cc"
,
"memory"
,
"x1"
,
"x2"
,
"x3"
,
"x4"
,
"x5"
,
"x6"
,
"x7"
,
"x8"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
// clang-format on
#undef LOAD_C_4
#undef STORE_C_4
}
static
void
kern_4x8_remain
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
K
/=
8
;
LDC
=
LDC
*
sizeof
(
int16_t
);
const
int8_t
*
a_ptr
=
packB
;
//packA;
const
int8_t
*
b_ptr
=
packA
;
//packB;
// clang-format off
register
int16_t
*
outptr
asm
(
"x0"
)
=
output
;
asm
volatile
(
"eor v24.16b, v24.16b, v24.16b
\n
"
"eor v25.16b, v25.16b, v25.16b
\n
"
"eor v26.16b, v26.16b, v26.16b
\n
"
"eor v27.16b, v27.16b, v27.16b
\n
"
"eor v28.16b, v28.16b, v28.16b
\n
"
"eor v29.16b, v29.16b, v29.16b
\n
"
"eor v30.16b, v30.16b, v30.16b
\n
"
"eor v31.16b, v31.16b, v31.16b
\n
"
// General loop.
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"PRFM PLDL1KEEP, [%[a_ptr], #512]
\n
"
"PRFM PLDL1KEEP, [%[b_ptr], #512]
\n
"
"1:
\n
"
"dup v0.8b,v20.b[0]
\n
"
"ld1 {v22.16b}, [%[a_ptr]],#16
\n
"
"dup v1.8b,v20.b[1]
\n
"
"ld1 {v23.16b}, [%[a_ptr]],#16
\n
"
"dup v2.8b,v20.b[2]
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v3.8b,v20.b[3]
\n
"
"dup v4.8b,v20.b[4]
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v5.8b,v20.b[5]
\n
"
"dup v6.8b,v20.b[6]
\n
"
"dup v7.8b,v20.b[7]
\n
"
"dup v8.8b,v20.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v20.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v20.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v20.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v20.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v20.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v20.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v20.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v21.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v21.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v21.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v21.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v21.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v21.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v21.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v21.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v21.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v21.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v21.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v21.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v21.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v21.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v21.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v21.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v22.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v22.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v22.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v22.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v22.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v22.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v22.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v22.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v22.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v22.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v22.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v22.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v22.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v22.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v22.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v22.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v23.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v23.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v23.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v23.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v23.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v23.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v23.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v23.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v23.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v23.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v23.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v23.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v23.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v23.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v23.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v23.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"subs %w[K], %w[K], #1
\n
"
"cbnz %w[K], 1b
\n
"
"cmp %w[is_first_k], #1
\n
"
"beq 2f
\n
"
"cmp %w[n_remain],#7
\n
"
"beq 200f
\n
"
"cmp %w[n_remain],#6
\n
"
"beq 201f
\n
"
"cmp %w[n_remain],#5
\n
"
"beq 202f
\n
"
"cmp %w[n_remain],#4
\n
"
"beq 203f
\n
"
"cmp %w[n_remain],#3
\n
"
"beq 204f
\n
"
"cmp %w[n_remain],#2
\n
"
"beq 205f
\n
"
"cmp %w[n_remain],#1
\n
"
"beq 206f
\n
"
"200:
\n
"
"ld1 {v0.8h}, [x0],#16
\n
"
"ld1 {v1.8h}, [x0],#16
\n
"
"ld1 {v2.8h}, [x0],#16
\n
"
"ld1 {v3.d}[0], [x0],#8
\n
"
"b 3f
\n
"
"201:
\n
"
"ld1 {v0.8h}, [x0],#16
\n
"
"ld1 {v1.8h}, [x0],#16
\n
"
"ld1 {v2.8h}, [x0],#16
\n
"
"b 3f
\n
"
"202:
\n
"
"ld1 {v0.8h}, [x0],#16
\n
"
"ld1 {v1.8h}, [x0],#16
\n
"
"ld1 {v2.d}[0], [x0],#8
\n
"
"b 3f
\n
"
"203:
\n
"
"ld1 {v0.8h}, [x0],#16
\n
"
"ld1 {v1.8h}, [x0],#16
\n
"
"b 3f
\n
"
"204:
\n
"
"ld1 {v0.8h}, [x0],#16
\n
"
"ld1 {v1.d}[0], [x0],#8
\n
"
"b 3f
\n
"
"205:
\n
"
"ld1 {v0.8h}, [x0],#16
\n
"
"b 3f
\n
"
"206:
\n
"
"ld1 {v0.d}[0], [x0],#8
\n
"
"b 3f
\n
"
"2:
\n
"
"eor v0.16b, v0.16b, v0.16b
\n
"
"eor v1.16b, v1.16b, v1.16b
\n
"
"eor v2.16b, v2.16b, v2.16b
\n
"
"eor v3.16b, v3.16b, v3.16b
\n
"
"eor v4.16b, v4.16b, v4.16b
\n
"
"eor v5.16b, v5.16b, v5.16b
\n
"
"eor v6.16b, v6.16b, v6.16b
\n
"
"eor v7.16b, v7.16b, v7.16b
\n
"
"3:
\n
"
"zip1 v8.2d, v24.2d, v25.2d
\n
"
"zip1 v10.2d, v26.2d, v27.2d
\n
"
"add v0.8h, v0.8h, v8.8h
\n
"
"zip1 v12.2d, v28.2d, v29.2d
\n
"
"add v1.8h, v1.8h, v10.8h
\n
"
"zip1 v14.2d, v30.2d, v31.2d
\n
"
"add v2.8h, v2.8h, v12.8h
\n
"
"add v3.8h, v3.8h, v14.8h
\n
"
// Store back into memory
"cmp %w[n_remain],#7
\n
"
"beq 100f
\n
"
"cmp %w[n_remain],#6
\n
"
"beq 101f
\n
"
"cmp %w[n_remain],#5
\n
"
"beq 102f
\n
"
"cmp %w[n_remain],#4
\n
"
"beq 103f
\n
"
"cmp %w[n_remain],#3
\n
"
"beq 104f
\n
"
"cmp %w[n_remain],#2
\n
"
"beq 105f
\n
"
"cmp %w[n_remain],#1
\n
"
"beq 106f
\n
"
"100:
\n
"
"st1 {v0.8h}, [x0],#16
\n
"
"st1 {v1.8h}, [x0],#16
\n
"
"st1 {v2.8h}, [x0],#16
\n
"
"st1 {v3.d}[0], [x0],#8
\n
"
"b 1000f
\n
"
"101:
\n
"
"st1 {v0.8h}, [x0],#16
\n
"
"st1 {v1.8h}, [x0],#16
\n
"
"st1 {v2.8h}, [x0],#16
\n
"
"b 1000f
\n
"
"102:
\n
"
"st1 {v0.8h}, [x0],#16
\n
"
"st1 {v1.8h}, [x0],#16
\n
"
"st1 {v2.d}[0], [x0],#8
\n
"
"b 1000f
\n
"
"103:
\n
"
"st1 {v0.8h}, [x0],#16
\n
"
"st1 {v1.8h}, [x0],#16
\n
"
"b 1000f
\n
"
"104:
\n
"
"st1 {v0.8h}, [x0],#16
\n
"
"st1 {v1.d}[0], [x0],#8
\n
"
"b 1000f
\n
"
"105:
\n
"
"st1 {v0.8h}, [x0],#16
\n
"
"b 1000f
\n
"
"106:
\n
"
"st1 {v0.d}[0], [x0],#8
\n
"
"b 1000f
\n
"
"1000:
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
K
]
"+r"
(
K
),
[
LDC
]
"+r"
(
LDC
),
[
outptr
]
"+r"
(
outptr
),
[
m_remain
]
"+r"
(
m_remain
),
[
n_remain
]
"+r"
(
n_remain
)
:
:
"cc"
,
"memory"
,
"x1"
,
"x2"
,
"x3"
,
"x4"
,
"x5"
,
"x6"
,
"x7"
,
"x8"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
// clang-format on
#undef LOAD_C_4
#undef STORE_C_4
}
//! Pack matrix A (MK4 layout) into the ic x oc panel order used by the
//! 8x8x8 kernels.
//!
//! Layout change: (M/4, K/4, 4(K), 4(M)) ->
//! (M/8, K/8, 8(K: ic0~3 then ic4~7), 8(M: oc0~3 then oc4~7)).
//! When M or K is not a multiple of 8 the missing part is filled with zeros.
//!
//! \param outptr destination buffer, written sequentially by the interleave
//!               helper
//! \param inptr  source matrix
//! \param ldin   stride, in elements, between consecutive groups of 4 rows
//! \param m0,mmax row range to pack; both must be multiples of 4
//! \param k0,kmax depth range to pack; both must be multiples of 4
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
                                          const dt_int8* inptr, int ldin,
                                          int m0, int mmax, int k0, int kmax) {
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_m = 8;
    constexpr int pack_k = 8;
    constexpr int pack_size = 4;
    // Staging buffers for a partial K block; they stay zero beyond the bytes
    // copied in, which is what provides the zero padding.
    int8_t tmpbuff0[pack_m * pack_size] = {0};
    int8_t tmpbuff1[pack_m * pack_size] = {0};
    // Stand-in for the second row group when only 4 rows are left.
    int8_t zerobuff[pack_m * pack_size] = {0};

    const int full_rows_end = (mmax - m0) / pack_m * pack_m + m0;
    const int rows_left = mmax - full_rows_end;

    // Packs one 8-row panel starting at row_begin. With has_second_group ==
    // false the upper four rows are taken from zerobuff instead of the input.
    auto pack_panel = [&](int row_begin, bool has_second_group) {
        const int8_t* group0 = inptr + row_begin / pack_size * ldin + k0;
        const int8_t* group1 = group0 + ldin;
        prefetch_2x(group0);
        prefetch_2x(group1);

        int k_pos = k0;
        while (k_pos + 7 < kmax) {
            if (!has_second_group) {
                // Re-point at the zero block before every call.
                // NOTE(review): presumably needed because the helper
                // advances its pointer arguments -- confirm.
                group1 = zerobuff;
            }
            interleave_8x8_mk4_b(group0, group1, outptr);
            k_pos += pack_k;
        }

        if (k_pos < kmax) {
            // Partial K block: stage the valid bytes in zero-padded buffers.
            const auto tail_bytes =
                    sizeof(int8_t) * (kmax - k_pos) * pack_size;
            memcpy(tmpbuff0, group0, tail_bytes);
            if (has_second_group) {
                memcpy(tmpbuff1, group1, tail_bytes);
                group1 = tmpbuff1;
            } else {
                group1 = zerobuff;
            }
            group0 = tmpbuff0;
            interleave_8x8_mk4_b(group0, group1, outptr);
        }
    };

    for (int row = m0; row < full_rows_end; row += pack_m) {
        pack_panel(row, true);
    }
    if (rows_left == 4) {
        pack_panel(full_rows_end, false);
    }
}
//! Pack matrix B (MK4 layout) into the n x ic panel order used by the
//! 8x8x8 kernels.
//!
//! Layout change: (K/4, N, 4) -> (K/8, N, 8(ic0~7)). When K is not a
//! multiple of 8 the missing half of the last K block is filled with zeros.
//!
//! \param out    destination buffer; N blocks are ksize * 8 bytes apart and
//!               each K block occupies 64 bytes inside an N block
//! \param in     source matrix
//! \param ldin   stride, in elements, between consecutive groups of 4 k-rows
//! \param n0,nmax column range to pack
//! \param k0,kmax depth range to pack; both must be multiples of 4
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
                                          int ldin, int n0, int nmax, int k0,
                                          int kmax) {
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");

    constexpr int pack_n = 8;
    constexpr int pack_k = 8;
    constexpr int pack_size = 4;
    // Staging buffers for a partial N block; untouched bytes remain zero.
    int8_t tmpbuff0[pack_n * pack_size] = {0};
    int8_t tmpbuff1[pack_n * pack_size] = {0};
    // Stand-in for the second k-group when only 4 k values are left.
    int8_t zerobuff[pack_n * pack_size] = {0};

    const int ksize = round_up<int>((kmax - k0), 8);
    const int nsize = nmax - n0;
    const int n_end = nsize / pack_n * pack_n + n0;
    const int remain_n = nsize % pack_n;
    const int output_stride = ksize * pack_n;

    // Packs every N block of one 8-deep K slice starting at k_pos into dst.
    // With has_second_group == false the upper four k values come from
    // zerobuff and the second source row group is never read or prefetched.
    auto pack_k_slice = [&](int k_pos, bool has_second_group, int8_t* dst) {
        const int8_t* group0 = in + k_pos / pack_size * ldin + n0 * pack_size;
        const int8_t* group1 = has_second_group ? group0 + ldin : nullptr;
        prefetch_3x(group0);
        if (has_second_group) {
            prefetch_3x(group1);
        }

        for (int n_pos = n0; n_pos < n_end; n_pos += pack_n) {
            if (!has_second_group) {
                group1 = zerobuff;
            }
            transpose_8x8_mk4_b(group0, group1, dst);
            dst += output_stride;
        }

        if (remain_n > 0) {
            // Partial N block: stage the valid columns in zero-padded buffers.
            const auto tail_bytes = sizeof(int8_t) * remain_n * pack_size;
            memcpy(tmpbuff0, group0, tail_bytes);
            if (has_second_group) {
                memcpy(tmpbuff1, group1, tail_bytes);
                group1 = tmpbuff1;
            } else {
                group1 = zerobuff;
            }
            group0 = tmpbuff0;
            transpose_8x8_mk4_b(group0, group1, dst);
        }
    };

    int8_t* slice_base = out;
    int k_pos = k0;
    while (k_pos + 7 < kmax) {
        pack_k_slice(k_pos, true, slice_base);
        slice_base += pack_n * pack_k;
        k_pos += pack_k;
    }
    if (k_pos < kmax) {
        pack_k_slice(k_pos, false, slice_base);
    }
}
}
// namespace matmul_mk4_8x8x8
}
// namespace aarch64
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp
浏览文件 @
92b12685
...
...
@@ -13,6 +13,7 @@
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
...
...
@@ -357,4 +358,81 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_mk4_8x8x8==================================
// NOTE(review): presumably expands to the out-of-line boilerplate for the
// gemm_s8x8x16_mk4_8x8x8 strategy class whose pack_A / pack_B / kern members
// are defined below -- confirm against the macro definition.
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_s8x8x16_mk4_8x8x8
);
//! Strategy entry point for packing A: forwards the row range [y0, ymax) and
//! depth range [k0, kmax) to the MK4 8x8x8 packing routine. The trailing
//! unnamed bool argument is ignored.
void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin,
                                    int y0, int ymax, int k0, int kmax,
                                    bool) const {
    namespace kernel = matmul_mk4_8x8x8;
    kernel::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}
//! Strategy entry point for packing B: forwards the column range [x0, xmax)
//! and depth range [k0, kmax) to the MK4 8x8x8 packing routine. The trailing
//! unnamed bool argument is ignored.
void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                    int x0, int xmax, int k0, int kmax,
                                    bool) const {
    namespace kernel = matmul_mk4_8x8x8;
    kernel::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
//! Drive the 8x8x8 micro-kernels over the packed panels.
//!
//! Rows are processed in blocks of 8, with one trailing block of 4 when
//! M % 8 == 4; columns are processed in blocks of 8 plus one partial block of
//! N % 8 columns. K is rounded up to a multiple of 8 before it is handed to
//! the kernels. Only is_first_k == true is supported (asserted below). The
//! two trailing unnamed pointer arguments are ignored.
void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
                                  size_t M, size_t N, size_t K, dt_int16* C,
                                  size_t LDC, bool is_first_k, const dt_int16*,
                                  dt_int16*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  C_dtype.enumv() == DTypeEnum::Int16 &&
                  A_dtype.enumv() == DTypeEnum::Int8);
    megdnn_assert(is_first_k == true, "only impl is_first_k");
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");

    constexpr size_t pack_size = 4;
    constexpr size_t block_m = 8;
    constexpr size_t block_n = 8;
    const size_t cols_left = N % block_n;
    const size_t rows_left = M % block_m;

    K = round_up<size_t>(K, 8);
    // Size in bytes of one packed 8-wide panel (A or B) over the padded K.
    const size_t panel_size = K * block_n;

    size_t row = 0;
    while (row + block_m <= M) {
        int16_t* dst = C + (row / pack_size * LDC);
        const int8_t* b_panel = packB;
        for (size_t col = 0; col + block_n <= N; col += block_n) {
            matmul_mk4_8x8x8::kern_8x8(packA, b_panel, K, dst, LDC, is_first_k,
                                       block_m, block_n);
            dst += block_n * pack_size;
            b_panel += panel_size;
        }
        if (cols_left > 0) {
            matmul_mk4_8x8x8::kern_8x8_remain(packA, b_panel, K, dst, LDC,
                                              is_first_k, block_m, cols_left);
        }
        packA += panel_size;
        row += block_m;
    }

    if (rows_left == 4) {
        // Trailing half-height block; packA already points at its panel.
        int16_t* dst = C + (row / pack_size * LDC);
        const int8_t* b_panel = packB;
        for (size_t col = 0; col + block_n <= N; col += block_n) {
            matmul_mk4_8x8x8::kern_4x8(packA, b_panel, K, dst, LDC, is_first_k,
                                       4, block_n);
            dst += block_n * pack_size;
            b_panel += panel_size;
        }
        if (cols_left > 0) {
            matmul_mk4_8x8x8::kern_4x8_remain(packA, b_panel, K, dst, LDC,
                                              is_first_k, 4, cols_left);
        }
    }
}
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h
浏览文件 @
92b12685
...
...
@@ -26,6 +26,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false,
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE
(
dt_int8
,
dt_int16
,
dt_int16
,
dt_int16
,
16
,
12
,
4
,
false
,
false
,
gemm_s8x8x16_mk4_16x12_a53
);
// Declares the gemm_s8x8x16_mk4_8x8x8 strategy.
// NOTE(review): the three 8s are presumably the M/N/K block sizes, matching
// the pack_m / pack_n / round-up-to-8 values used by this strategy's kern in
// strategy.cpp -- confirm against the macro definition.
MEGDNN_REG_GEMM_STRATEGY
(
dt_int8
,
dt_int16
,
dt_int16
,
8
,
8
,
8
,
false
,
false
,
gemm_s8x8x16_mk4_8x8x8
);
}
// namespace matmul
}
// namespace aarch64
...
...
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
92b12685
...
...
@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16K4x4x16
int8x8x16_k4x4x16
;
AlgoInt8x8x16MK4_16x12x4
int8x8x16_mk4_16x12x4
;
AlgoInt8x8x16MK4_4x4x8
int8x8x16_mk4_4x4x8
;
AlgoInt8x8x16MK4_K8x8x8
int8x8x16_mk4_k8x8x8
;
AlgoInt16x16x32K12x8x1
int16x16x32_k12x8x1
;
AlgoInt16x16x32MK8_8x8
int16x16x32_mk8_8x8
;
...
...
@@ -73,6 +74,7 @@ public:
#endif
all_algos
.
emplace_back
(
&
int8x8x16_k4x4x16
);
all_algos
.
emplace_back
(
&
int8x8x16_k8x8x8
);
all_algos
.
emplace_back
(
&
int8x8x16_mk4_k8x8x8
);
all_algos
.
emplace_back
(
&
int8x8x16_mk4_4x4x8
);
all_algos
.
emplace_back
(
&
int8x8x16_mk4_16x12x4
);
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
92b12685
...
...
@@ -57,6 +57,7 @@ private:
#else
class
AlgoQuint8K8x8x8
;
// Aarch64 Quint8 Kernel 8x8x8
#endif
class
AlgoInt8x8x16MK4_K8x8x8
;
// Aarch64 Int8x8x16 MK4 Kernel 8x8x8
class
AlgoPack
;
};
...
...
dnn/test/aarch64/matrix_mul.cpp
浏览文件 @
92b12685
...
...
@@ -122,6 +122,20 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) {
std
::
move
(
args
));
}
//! Correctness sweep for AARCH64_INT8X8X16_MK4_K8X8X8 in MK4 format:
//! m in [1, 17], n in [1, 16] plus 24, k in [2, 29], enumerated with m as the
//! outermost and k as the innermost index.
TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) {
    std::vector<size_t> n_values;
    for (size_t n = 1; n <= 16; ++n) {
        n_values.push_back(n);
    }
    n_values.push_back(24);

    std::vector<matrix_mul::TestArg> args;
    for (size_t m = 1; m <= 17; ++m) {
        for (size_t n : n_values) {
            for (size_t k = 2; k <= 29; ++k) {
                args.emplace_back(m, n, k, 0);
            }
        }
    }

    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "AARCH64_INT8X8X16_MK4_K8X8X8",
                                 param::MatrixMul::Format::MK4, 1, 1e-3,
                                 std::move(args));
}
TEST_F
(
AARCH64
,
MATRIX_MUL_MK4_8x8x16_4x4
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int16
{},
handle
(),
"AARCH64_INT8X8X16_MK4_4X4X8"
,
...
...
@@ -396,6 +410,71 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
run
(
384
,
384
,
384
);
}
//! Benchmark the MK4 8x8x8 int8x8x16 kernel against the default-format
//! K4X4X16 kernel and the MK4 4x4x8 kernel, printing time, throughput and
//! relative speedups for a set of problem sizes.
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) {
    constexpr size_t RUNS = 50;
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_mk4(handle());
    Benchmarker<MatrixMul> benchmarker_mk4_4x4x8(handle());

    // Settings shared by all three benchmarkers; `param` is read at call
    // time, so the format change below is picked up by the MK4 ones only.
    auto apply_common = [&](Benchmarker<MatrixMul>& bench) {
        bench.set_times(RUNS)
                .set_dtype(0, dtype::Int8{})
                .set_dtype(1, dtype::Int8{})
                .set_dtype(2, dtype::Int16{})
                .set_param(param)
                .set_display(false);
    };

    apply_common(benchmarker);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));

    param.format = MatrixMul::Param::Format::MK4;

    benchmarker_mk4.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_K8X8X8"));
    apply_common(benchmarker_mk4);

    benchmarker_mk4_4x4x8.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8"));
    apply_common(benchmarker_mk4_4x4x8);

    auto run = [&](size_t M, size_t N, size_t K) {
        // Per-run time of each variant (total divided by RUNS).
        auto time_default = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto time_mk4 =
                benchmarker_mk4.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        auto time_mk4_4x4x8 =
                benchmarker_mk4_4x4x8.exec(
                        {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
                RUNS;
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
               "%f Gflops speedup: %f, mk4_4x4x8 %f Gflops %f ms speedup: %f\n",
               M, K, N, time_default, computations / time_default, time_mk4,
               computations / time_mk4, time_default / time_mk4,
               computations / time_mk4_4x4x8, time_mk4_4x4x8,
               time_mk4_4x4x8 / time_mk4);
    };

    run(384, 384, 384);
    run(512, 512, 512);
    run(1024, 1024, 384);
    run(256, 256, 384);

    // Power-of-two sweep; note that k stops below 512 while m and n include it.
    for (int m = 32; m <= 512; m *= 2) {
        for (int n = 32; n <= 512; n *= 2) {
            for (int k = 32; k < 512; k *= 2) {
                run(m, n, k);
            }
        }
    }
}
TEST_F
(
AARCH64
,
BENCHMARK_MATRIX_MUL_INT16_4X4X16
)
{
constexpr
size_t
RUNS
=
50
;
param
::
MatrixMul
param
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录