Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
07d1d0ab
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
07d1d0ab
编写于
5月 12, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm64): add fp32 mk4 matmul
GitOrigin-RevId: f6df006547e08ba5b76be984a2fe87cf053c31de
上级
7ba641fe
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
1123 addition
and
2 deletion
+1123
-2
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+61
-0
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+11
-0
dnn/src/aarch64/matrix_mul/asm/common.h
dnn/src/aarch64/matrix_mul/asm/common.h
+65
-0
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
+4
-0
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
+874
-0
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
+77
-0
dnn/src/aarch64/matrix_mul/fp32/strategy.h
dnn/src/aarch64/matrix_mul/fp32/strategy.h
+3
-0
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+2
-0
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+1
-0
dnn/src/arm_common/conv_bias/f16/algos.cpp
dnn/src/arm_common/conv_bias/f16/algos.cpp
+2
-0
dnn/src/arm_common/conv_bias/fp32/algos.cpp
dnn/src/arm_common/conv_bias/fp32/algos.cpp
+4
-0
dnn/src/arm_common/conv_bias/int8/algos.cpp
dnn/src/arm_common/conv_bias/int8/algos.cpp
+2
-0
dnn/test/aarch64/matrix_mul.cpp
dnn/test/aarch64/matrix_mul.cpp
+15
-0
dnn/test/common/matrix_mul.cpp
dnn/test/common/matrix_mul.cpp
+2
-2
未找到文件。
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
07d1d0ab
...
...
@@ -86,6 +86,67 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
"AlgoF32K8x12x1Impl"
_hash
,
aarch64
::
matmul
::
sgemm_8x12
,
float
,
float
);
/* ===================== F32_MK4_8X12X1 algo ===================== */
bool
MatrixMulImpl
::
AlgoF32MK4_8x12x1
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
kern_size_param
.
B_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
C_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
A_type
==
dtype
::
Float32
()
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
MK4
&&
!
kern_size_param
.
trA
&&
!
kern_size_param
.
trB
&&
kern_size_param
.
M
%
4
==
0
&&
kern_size_param
.
K
%
4
==
0
;
}
/**
 * Return the workspace size reported by GemmInterleaved for the
 * sgemm_mk4_8x12 strategy built from the problem description in
 * \p kern_size_param.
 */
size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) {
        using Strategy = aarch64::matmul::sgemm_mk4_8x12;

        // Problem extents.
        auto M = kern_size_param.M;
        auto N = kern_size_param.N;
        auto K = kern_size_param.K;
        // Transpose flags.
        auto trA = kern_size_param.trA;
        auto trB = kern_size_param.trB;
        // Operand dtypes.
        auto A_type = kern_size_param.A_type;
        auto B_type = kern_size_param.B_type;
        auto C_type = kern_size_param.C_type;

        Strategy strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<Strategy>(M, N, K, trA, trB,
                                                         strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}
/**
 * Return the kernel entry point for this algorithm.
 *
 * The returned callable is a capture-less lambda. It builds an
 * sgemm_mk4_8x12 strategy from the KernParam it receives and runs
 * GemmInterleaved::execute on the float A/B/C pointers, using the
 * workspace supplied in the KernParam.
 */
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
        const KernSizeParam&) const {
    return [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) {
            using Strategy = aarch64::matmul::sgemm_mk4_8x12;

            // Problem extents and transpose flags.
            auto M = kern_param.M;
            auto N = kern_param.N;
            auto K = kern_param.K;
            auto trA = kern_param.trA;
            auto trB = kern_param.trB;
            // Leading dimensions of A, B and C.
            auto LDA = kern_param.LDA;
            auto LDB = kern_param.LDB;
            auto LDC = kern_param.LDC;
            // Operand dtypes.
            auto A_type = kern_param.A_type;
            auto B_type = kern_param.B_type;
            auto C_type = kern_param.C_type;
            // Typed data pointers; inputs are read-only.
            const auto Aptr = kern_param.A<float>();
            const auto Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();

            Strategy strategy(M, N, K, A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<Strategy> gemm(M, N, K, trA, trB,
                                                           strategy);
            gemm.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
}
// Instantiate the im2col GEMM hooks declared in the class by
// MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() for AlgoF32MK4_8x12x1, using the
// sgemm_mk4_8x12 strategy with float for both type arguments.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoF32MK4_8x12x1Impl"_hash,
                                     aarch64::matmul::sgemm_mk4_8x12, float,
                                     float);
/* ===================== F32K4X16X1 algo ===================== */
bool
MatrixMulImpl
::
AlgoF32K4x16x1
::
usable
(
...
...
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
07d1d0ab
...
...
@@ -29,6 +29,17 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
/**
 * Float32 matrix multiplication algorithm for the MK4 format on AArch64,
 * registered under the name "AARCH64_F32_MK4_K8X12X1".
 */
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
    // --- identity -------------------------------------------------------
    //! Name under which the algorithm is looked up.
    const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
    //! Algorithm family tag.
    void* type() const override { return sm_arm_common_algo_type; }
    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    // --- dispatch interface (defined in algos.cpp) ------------------------
    //! Whether the given problem can be handled by this algorithm.
    bool usable(const KernSizeParam&) const override;
    //! Workspace size required for the given problem.
    size_t get_workspace(const KernSizeParam&) const override;
    //! Kernel entry point for the given problem.
    kern_t get_kern(const KernSizeParam&) const override;

    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class
MatrixMulImpl
::
AlgoF32K4x16x1
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/aarch64/matrix_mul/asm/common.h
浏览文件 @
07d1d0ab
...
...
@@ -1103,6 +1103,36 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1,
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"cc"
,
"memory"
);
}
/**
 * Interleave two input rows in 16-byte units.
 *
 * Reads 64 bytes (four 128-bit vectors) from each of \p inptr0 and
 * \p inptr1 and writes 128 bytes to \p outptr in the order
 * in0[0], in1[0], in0[1], in1[1], in0[2], in1[2], in0[3], in1[3],
 * where inX[i] is the i-th 16-byte vector of that row.
 *
 * \param inptr0 first row; advanced by 64 bytes on return (passed by
 *               reference, post-indexed load)
 * \param inptr1 second row; advanced by 64 bytes on return
 * \param outptr destination; passed by value, so the caller's pointer is
 *               not advanced
 */
template <typename T>
static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1,
                                      T* outptr) {
    static_assert(sizeof(T) == 4, "interleave_2x4_4_s only support size == 4");
    asm volatile(
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
            "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr1]], #64\n"
            // Pair up vector i of row 0 with vector i of row 1.
            "stp q0, q4, [%[outptr]]\n"
            "stp q1, q5, [%[outptr], #32]\n"
            "stp q2, q6, [%[outptr], #64]\n"
            "stp q3, q7, [%[outptr], #96]\n"
            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
              [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory");
}
/**
 * Copy 64 bytes (four 128-bit vectors) from \p inptr0 to \p outptr
 * without reordering.
 *
 * \param inptr0 source; advanced by 64 bytes on return (passed by
 *               reference, post-indexed load)
 * \param outptr destination; passed by value, so the caller's pointer is
 *               not advanced
 */
template <typename T>
static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4, "interleave_1x4_4_s only support size == 4");
    asm volatile(
            "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
            "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
            : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "memory");
}
template
<
typename
T
>
static
inline
void
interleave_4x8_2_b
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1479,6 +1509,41 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1,
"v11"
,
"memory"
);
}
/**
 * Transpose 12 consecutive groups of 4 elements (48 elements, 192 bytes)
 * into 4 rows of 12 elements.
 *
 * Each ld4 de-interleaves 16 elements with stride 4, so after the three
 * loads v(4*g + j) holds element j of four consecutive input groups
 * (g = 0, 1, 2 selects groups 0-3, 4-7, 8-11). The stores then emit, for
 * j = 0..3, the vectors {v(j), v(4 + j), v(8 + j)}, i.e. element j of all
 * 12 groups.
 *
 * \param inptr0 source; advanced by 192 bytes on return (passed by
 *               reference, post-indexed loads)
 * \param outptr destination for 192 bytes; passed by value, so the
 *               caller's pointer is not advanced
 */
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4,
                  "transpose_1x12_4_s only support sizeof(T) == 4");
    asm volatile(
            "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
            "ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[inptr0]], #64\n"
            "ld4 {v8.4s, v9.4s, v10.4s, v11.4s},[%[inptr0]], #64\n"
            // Output order: (v0 v4 v8) (v1 v5 v9) (v2 v6 v10) (v3 v7 v11).
            "stp q0, q4, [%[outptr]]\n"
            "stp q8, q1, [%[outptr], #32]\n"
            "stp q5, q9, [%[outptr], #64]\n"
            "stp q2, q6, [%[outptr], #96]\n"
            "stp q10, q3, [%[outptr], #128]\n"
            "stp q7, q11, [%[outptr], #160]\n"
            : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
              "v10", "v11", "memory");
}
/**
 * Transpose a 4x4 tile of 4-byte elements (64 bytes).
 *
 * ld4 de-interleaves the 16 input elements with stride 4, so v(j) holds
 * element j of each of the four input groups; st1 writes v0..v3 back to
 * back.
 *
 * \param inptr0 source; advanced by 64 bytes on return (passed by
 *               reference, post-indexed load)
 * \param outptr destination for 64 bytes; passed by value, so the
 *               caller's pointer is not advanced
 */
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4,
                  "transpose_1x4_4_s only support sizeof(T) == 4");
    asm volatile(
            "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
            "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
            : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "memory");
}
template
<
typename
T
>
static
inline
void
transpose_8x4_1_s
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
浏览文件 @
07d1d0ab
...
...
@@ -899,6 +899,10 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
:
:
"v0"
,
"v1"
,
"v2"
,
"v5"
,
"v8"
,
"v11"
,
"v14"
,
"v17"
,
"x1"
,
"x2"
,
"x3"
,
"x10"
,
"cc"
,
"memory"
);
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
void
sgemm_8x12_pack_A_n
(
float
*
outptr
,
const
float
*
inptr
,
int
ldin
,
int
y0
,
...
...
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
0 → 100644
浏览文件 @
07d1d0ab
/**
* \file dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul_mk4_8x12
{
// 8x12 micro-kernel: two groups of 4 rows by 12 columns.
//
// Register usage, as established by the asm below:
//
//   v0       4 floats of packA for the first group of 4 rows
//   v1       4 floats of packA for the second group of 4 rows
//   v2-v4    12 floats of packB for the current k step
//   v5-v7    12 floats of packB for the following k step (the main loop
//            processes two k steps per iteration)
//   v8-v19   accumulators of the first row group:
//              v(8 + c)  += v0 * B[c],  c = 0..11
//            stored as 12 consecutive 4-float vectors at `output`
//   v20-v31  accumulators of the second row group:
//              v(20 + c) += v1 * B[c],  c = 0..11
//            stored as 12 consecutive 4-float vectors at `output + LDC`
//
// Parameters:
//   packA       packed lhs, 8 floats consumed per k step
//   packB       packed rhs, 12 floats consumed per k step
//   K           number of k steps
//   output      destination of the first row group
//   LDC         distance in floats from `output` to the second row group
//   is_first_k  true: accumulators start at zero;
//               false: accumulators are first loaded from the output
void kern_8x12(const float* packA, const float* packB, int K, float* output,
               int LDC, bool is_first_k) {
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    float* output0 = output;
    float* output1 = output0 + LDC;

    // The loop body handles two k steps; the last one (odd K) or two
    // (even K) steps are peeled into the tails below.
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;

    asm volatile(
            "cmp %w[is_first_k], #1\n"
            "beq 1f\n"
            // Not the first k block: reload the partial sums through
            // scratch pointers so output0/output1 stay untouched.
            "mov x1, %[output0]\n"
            "mov x2, %[output1]\n"
            "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64\n"
            "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64\n"
            "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64\n"
            "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64\n"
            "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64\n"
            "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64\n"

            "ld1 {v0.4s}, [%[a_ptr]], #16\n"
            "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
            "b 2f\n"

            // First k block: zero the accumulators, interleaved with the
            // initial operand loads and output prefetches.
            "1:\n"
            "eor v8.16b, v8.16b, v8.16b\n"
            "eor v9.16b, v9.16b, v9.16b\n"
            "eor v10.16b, v10.16b, v10.16b\n"
            "prfm pstl1keep, [%[output0]]\n"
            "eor v11.16b, v11.16b, v11.16b\n"
            "eor v12.16b, v12.16b, v12.16b\n"
            "eor v13.16b, v13.16b, v13.16b\n"
            "prfm pstl1keep, [%[output1]]\n"
            "eor v14.16b, v14.16b, v14.16b\n"
            "eor v15.16b, v15.16b, v15.16b\n"
            "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48\n"
            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v20.16b, v20.16b\n"
            "ld1 {v0.4s}, [%[a_ptr]], #16\n"
            "eor v21.16b, v21.16b, v21.16b\n"
            "eor v22.16b, v22.16b, v22.16b\n"
            "eor v23.16b, v23.16b, v23.16b\n"
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"

            "2:\n"
            "cmp %w[K], #0\n"
            "beq 4f\n"

            // Main loop: two k steps per iteration. On exit v0 and v2-v4
            // already hold the operands of the next k step.
            "3:\n"
            "fmla v8.4s, v0.4s, v2.s[0]\n"
            "fmla v9.4s, v0.4s, v2.s[1]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v10.4s, v0.4s, v2.s[2]\n"
            "fmla v11.4s, v0.4s, v2.s[3]\n"
            "fmla v12.4s, v0.4s, v3.s[0]\n"
            "fmla v13.4s, v0.4s, v3.s[1]\n"
            "fmla v14.4s, v0.4s, v3.s[2]\n"
            "fmla v15.4s, v0.4s, v3.s[3]\n"
            "fmla v16.4s, v0.4s, v4.s[0]\n"
            "fmla v17.4s, v0.4s, v4.s[1]\n"
            "fmla v18.4s, v0.4s, v4.s[2]\n"
            "fmla v19.4s, v0.4s, v4.s[3]\n"
            "fmla v20.4s, v1.4s, v2.s[0]\n"
            "fmla v21.4s, v1.4s, v2.s[1]\n"
            "fmla v22.4s, v1.4s, v2.s[2]\n"
            "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
            "fmla v23.4s, v1.4s, v2.s[3]\n"
            "fmla v24.4s, v1.4s, v3.s[0]\n"
            "fmla v25.4s, v1.4s, v3.s[1]\n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "fmla v26.4s, v1.4s, v3.s[2]\n"
            "fmla v27.4s, v1.4s, v3.s[3]\n"
            "fmla v28.4s, v1.4s, v4.s[0]\n"
            "fmla v29.4s, v1.4s, v4.s[1]\n"
            "fmla v30.4s, v1.4s, v4.s[2]\n"
            "fmla v31.4s, v1.4s, v4.s[3]\n"

            "fmla v8.4s, v0.4s, v5.s[0]\n"
            "fmla v9.4s, v0.4s, v5.s[1]\n"
            "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48\n"
            "fmla v10.4s, v0.4s, v5.s[2]\n"
            "fmla v11.4s, v0.4s, v5.s[3]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v12.4s, v0.4s, v6.s[0]\n"
            "fmla v13.4s, v0.4s, v6.s[1]\n"
            "fmla v14.4s, v0.4s, v6.s[2]\n"
            "fmla v15.4s, v0.4s, v6.s[3]\n"
            "fmla v16.4s, v0.4s, v7.s[0]\n"
            "fmla v17.4s, v0.4s, v7.s[1]\n"
            "fmla v18.4s, v0.4s, v7.s[2]\n"
            "fmla v19.4s, v0.4s, v7.s[3]\n"
            "fmla v20.4s, v1.4s, v5.s[0]\n"
            "fmla v21.4s, v1.4s, v5.s[1]\n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "fmla v22.4s, v1.4s, v5.s[2]\n"
            "fmla v23.4s, v1.4s, v5.s[3]\n"
            "fmla v24.4s, v1.4s, v6.s[0]\n"
            "subs %w[K], %w[K], #1\n"
            "fmla v25.4s, v1.4s, v6.s[1]\n"
            "fmla v26.4s, v1.4s, v6.s[2]\n"
            "fmla v27.4s, v1.4s, v6.s[3]\n"
            "fmla v28.4s, v1.4s, v7.s[0]\n"
            "fmla v29.4s, v1.4s, v7.s[1]\n"
            "fmla v30.4s, v1.4s, v7.s[2]\n"
            "fmla v31.4s, v1.4s, v7.s[3]\n"
            "bne 3b\n"

            "4:\n"
            "cmp %w[oddk], #1\n"
            "beq 5f\n"

            // Even tail: two remaining k steps, stores interleaved with
            // the last multiply-accumulates.
            "fmla v8.4s, v0.4s, v2.s[0]\n"
            "fmla v9.4s, v0.4s, v2.s[1]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v10.4s, v0.4s, v2.s[2]\n"
            "fmla v11.4s, v0.4s, v2.s[3]\n"
            "fmla v12.4s, v0.4s, v3.s[0]\n"
            "fmla v13.4s, v0.4s, v3.s[1]\n"
            "fmla v14.4s, v0.4s, v3.s[2]\n"
            "fmla v15.4s, v0.4s, v3.s[3]\n"
            "fmla v16.4s, v0.4s, v4.s[0]\n"
            "fmla v17.4s, v0.4s, v4.s[1]\n"
            "fmla v18.4s, v0.4s, v4.s[2]\n"
            "fmla v19.4s, v0.4s, v4.s[3]\n"
            "fmla v20.4s, v1.4s, v2.s[0]\n"
            "fmla v21.4s, v1.4s, v2.s[1]\n"
            "fmla v22.4s, v1.4s, v2.s[2]\n"
            "ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48\n"
            "fmla v23.4s, v1.4s, v2.s[3]\n"
            "fmla v24.4s, v1.4s, v3.s[0]\n"
            "fmla v25.4s, v1.4s, v3.s[1]\n"
            "ld1 {v0.4s}, [%[a_ptr]], 16\n"
            "fmla v26.4s, v1.4s, v3.s[2]\n"
            "fmla v27.4s, v1.4s, v3.s[3]\n"
            "fmla v28.4s, v1.4s, v4.s[0]\n"
            "fmla v29.4s, v1.4s, v4.s[1]\n"
            "fmla v30.4s, v1.4s, v4.s[2]\n"
            "fmla v31.4s, v1.4s, v4.s[3]\n"

            "fmla v8.4s, v0.4s, v5.s[0]\n"
            "fmla v9.4s, v0.4s, v5.s[1]\n"
            "fmla v10.4s, v0.4s, v5.s[2]\n"
            "fmla v11.4s, v0.4s, v5.s[3]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v12.4s, v0.4s, v6.s[0]\n"
            "fmla v13.4s, v0.4s, v6.s[1]\n"
            "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
            "fmla v14.4s, v0.4s, v6.s[2]\n"
            "fmla v15.4s, v0.4s, v6.s[3]\n"
            "fmla v16.4s, v0.4s, v7.s[0]\n"
            "fmla v17.4s, v0.4s, v7.s[1]\n"
            "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
            "fmla v18.4s, v0.4s, v7.s[2]\n"
            "fmla v19.4s, v0.4s, v7.s[3]\n"
            "fmla v20.4s, v1.4s, v5.s[0]\n"
            "fmla v21.4s, v1.4s, v5.s[1]\n"
            "fmla v22.4s, v1.4s, v5.s[2]\n"
            "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
            "fmla v23.4s, v1.4s, v5.s[3]\n"
            "fmla v24.4s, v1.4s, v6.s[0]\n"
            "fmla v25.4s, v1.4s, v6.s[1]\n"
            "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n"
            "fmla v26.4s, v1.4s, v6.s[2]\n"
            "fmla v27.4s, v1.4s, v6.s[3]\n"
            "fmla v28.4s, v1.4s, v7.s[0]\n"
            "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n"
            "fmla v29.4s, v1.4s, v7.s[1]\n"
            "fmla v30.4s, v1.4s, v7.s[2]\n"
            "fmla v31.4s, v1.4s, v7.s[3]\n"
            "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n"
            "b 6f\n"

            // odd tail: one remaining k step.
            "5:\n"
            "fmla v8.4s, v0.4s, v2.s[0]\n"
            "fmla v9.4s, v0.4s, v2.s[1]\n"
            "fmla v10.4s, v0.4s, v2.s[2]\n"
            "fmla v11.4s, v0.4s, v2.s[3]\n"
            "fmla v12.4s, v0.4s, v3.s[0]\n"
            "fmla v13.4s, v0.4s, v3.s[1]\n"
            "ld1 {v1.4s}, [%[a_ptr]], 16\n"
            "fmla v14.4s, v0.4s, v3.s[2]\n"
            "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64\n"
            "fmla v15.4s, v0.4s, v3.s[3]\n"
            "fmla v16.4s, v0.4s, v4.s[0]\n"
            "fmla v17.4s, v0.4s, v4.s[1]\n"
            "fmla v18.4s, v0.4s, v4.s[2]\n"
            "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64\n"
            "fmla v19.4s, v0.4s, v4.s[3]\n"
            "fmla v20.4s, v1.4s, v2.s[0]\n"
            "fmla v21.4s, v1.4s, v2.s[1]\n"
            "fmla v22.4s, v1.4s, v2.s[2]\n"
            "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64\n"
            "fmla v23.4s, v1.4s, v2.s[3]\n"
            "fmla v24.4s, v1.4s, v3.s[0]\n"
            "fmla v25.4s, v1.4s, v3.s[1]\n"
            "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[output1]], #64\n"
            "fmla v26.4s, v1.4s, v3.s[2]\n"
            "fmla v27.4s, v1.4s, v3.s[3]\n"
            "fmla v28.4s, v1.4s, v4.s[0]\n"
            "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output1]], #64\n"
            "fmla v29.4s, v1.4s, v4.s[1]\n"
            "fmla v30.4s, v1.4s, v4.s[2]\n"
            "fmla v31.4s, v1.4s, v4.s[3]\n"
            "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output1]], #64\n"

            "6:\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
              [output0] "+r"(output0), [output1] "+r"(output1)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
              "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
              "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
              "v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory");
}
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in v2-v7
// A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
// A 8x12 block of accumulators is stored in 32bit in v8-v31.
//
// +--------+
// | v2[0-3]|
// | v3[0-3]|
// Rhs +--------+
//
// | |
//
// Lhs | |
//
// +--+ --- - +--------+
// |v0| | v8[0-3]|
// |v0| |v11[0-3]|
// |v0| |v14[0-3]|
// |v0| |v17[0-3]|
// |v1| |v20[0-3]|
// |v1| |v23[0-3]|
// |v1| |v26[0-3]|
// |v1| |v29[0-3]|
// +--+ --- - +--------+
//
// Accumulator
void
kern_8x4
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
,
int
n_remain
)
{
const
float
*
a_ptr
=
packA
;
const
float
*
b_ptr
=
packB
;
float
*
output0
=
output
;
float
*
output1
=
output0
+
LDC
;
int
oddk
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
//clang-format off
#define LOAD_C \
"cmp %w[n_remain], #4\n" \
"blt 11f\n" \
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \
"b 14f\n" \
"11:\n" \
"cmp %w[n_remain], #3\n" \
"blt 12f\n" \
"ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"ld1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \
"b 14f\n" \
"12:\n" \
"cmp %w[n_remain], #2\n" \
"blt 13f\n" \
"ld1 {v8.4s, v9.4s}, [%[output0]]\n" \
"ld1 {v12.4s, v13.4s},[%[output1]]\n" \
"b 14f\n" \
"13:\n" \
"ld1 {v8.4s}, [%[output0]]\n" \
"ld1 {v12.4s},[%[output1]]\n" \
"14:\n"
#define STORE_C \
"cmp %w[n_remain], #4\n" \
"blt 21f\n" \
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"st1 {v12.4s, v13.4s, v14.4s, v15.4s},[%[output1]]\n" \
"b 24f\n" \
"21:\n" \
"cmp %w[n_remain], #3\n" \
"blt 22f\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"st1 {v12.4s, v13.4s, v14.4s},[%[output1]]\n" \
"b 23f\n" \
"22:\n" \
"cmp %w[n_remain], #2\n" \
"blt 23f\n" \
"st1 {v8.4s, v9.4s}, [%[output0]]\n" \
"st1 {v12.4s, v13.4s},[%[output1]]\n" \
"b 24f\n" \
"23:\n" \
"st1 {v8.4s}, [%[output0]]\n" \
"st1 {v12.4s},[%[output1]]\n" \
"24:\n"
//clang-format on
asm
volatile
(
// load accumulator C
"cmp %w[is_first_k], #1
\n
"
"beq 1f
\n
"
LOAD_C
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"ld1 {v2.4s}, [%[b_ptr]], #16
\n
"
"b 2f
\n
"
"1:
\n
"
"eor v8.16b, v8.16b, v8.16b
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"eor v9.16b, v9.16b, v9.16b
\n
"
"eor v10.16b, v10.16b, v10.16b
\n
"
"prfm pstl1keep, [%[output0]]
\n
"
"eor v11.16b, v11.16b, v11.16b
\n
"
"eor v12.16b, v12.16b, v12.16b
\n
"
"prfm pstl1keep, [%[output1]]
\n
"
"eor v13.16b, v13.16b, v13.16b
\n
"
"eor v14.16b, v14.16b, v14.16b
\n
"
"eor v15.16b, v15.16b, v15.16b
\n
"
"ld1 {v2.4s}, [%[b_ptr]], #16
\n
"
"2:
\n
"
"cmp %w[K], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], #16
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"ld1 {v3.4s}, [%[b_ptr]], #16
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v12.4s, v1.4s, v2.s[0]
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"fmla v13.4s, v1.4s, v2.s[1]
\n
"
"fmla v14.4s, v1.4s, v2.s[2]
\n
"
"fmla v15.4s, v1.4s, v2.s[3]
\n
"
"fmla v8.4s, v0.4s, v3.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], #16
\n
"
"fmla v9.4s, v0.4s, v3.s[1]
\n
"
"fmla v10.4s, v0.4s, v3.s[2]
\n
"
"fmla v11.4s, v0.4s, v3.s[3]
\n
"
"ld1 {v2.4s}, [%[b_ptr]], #16
\n
"
"fmla v12.4s, v1.4s, v3.s[0]
\n
"
"subs %w[K], %w[K], #1
\n
"
"fmla v13.4s, v1.4s, v3.s[1]
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"fmla v14.4s, v1.4s, v3.s[2]
\n
"
"fmla v15.4s, v1.4s, v3.s[3]
\n
"
"bne 3b
\n
"
"4:
\n
"
"cmp %w[oddk], #1
\n
"
"beq 5f
\n
"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], #16
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"ld1 {v3.4s}, [%[b_ptr]], #16
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v12.4s, v1.4s, v2.s[0]
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"fmla v13.4s, v1.4s, v2.s[1]
\n
"
"fmla v14.4s, v1.4s, v2.s[2]
\n
"
"fmla v15.4s, v1.4s, v2.s[3]
\n
"
"fmla v8.4s, v0.4s, v3.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], #16
\n
"
"fmla v9.4s, v0.4s, v3.s[1]
\n
"
"fmla v10.4s, v0.4s, v3.s[2]
\n
"
"fmla v11.4s, v0.4s, v3.s[3]
\n
"
"fmla v12.4s, v1.4s, v3.s[0]
\n
"
"fmla v13.4s, v1.4s, v3.s[1]
\n
"
"fmla v14.4s, v1.4s, v3.s[2]
\n
"
"fmla v15.4s, v1.4s, v3.s[3]
\n
"
"b 6f
\n
"
// odd tail
"5:
\n
"
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], #16
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v12.4s, v1.4s, v2.s[0]
\n
"
"fmla v13.4s, v1.4s, v2.s[1]
\n
"
"fmla v14.4s, v1.4s, v2.s[2]
\n
"
"fmla v15.4s, v1.4s, v2.s[3]
\n
"
"6:
\n
"
STORE_C
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
oddk
]
"+r"
(
oddk
),
[
output0
]
"+r"
(
output0
),
[
output1
]
"+r"
(
output1
),
[
n_remain
]
"+r"
(
n_remain
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"cc"
,
"memory"
);
#undef LOAD_C
#undef STORE_C
}
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in v2-v7
// A 8x1 cell of Lhs is stored in 32bit in (v0-v1)
// A 8x12 block of accumulators is stored in 32bit in v8-v31.
//
// +--------+--------+--------+
// | v2[0-3]| v3[0-3]| v4[0-3]|
// | v5[0-3]| v6[0-3]| v7[0-3]|
// Rhs +--------+--------+--------+
//
// | | | |
//
// Lhs | | | |
//
// +--+ --- - +--------+--------+--------+
// |v0| | v8[0-3]| v9[0-3]|v10[0-3]|
// |v0| |v11[0-3]|v12[0-3]|v13[0-3]|
// |v0| |v14[0-3]|v15[0-3]|v16[0-3]|
// |v0| |v17[0-3]|v18[0-3]|v19[0-3]|
// +--+ --- - +--------+--------+--------+
//
// Accumulator
void
kern_4x12
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
)
{
MEGDNN_MARK_USED_VAR
(
LDC
);
const
float
*
a_ptr
=
packA
;
const
float
*
b_ptr
=
packB
;
float
*
output0
=
output
;
int
oddk
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
asm
volatile
(
"cmp %w[is_first_k], #1
\n
"
"beq 1f
\n
"
"mov x1, %[output0]
\n
"
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
\n
"
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
\n
"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48
\n
"
"b 2f
\n
"
"1:
\n
"
"eor v8.16b, v8.16b, v8.16b
\n
"
"eor v9.16b, v9.16b, v9.16b
\n
"
"eor v10.16b, v10.16b, v10.16b
\n
"
"prfm pstl1keep, [%[output0]]
\n
"
"eor v11.16b, v11.16b, v11.16b
\n
"
"eor v12.16b, v12.16b, v12.16b
\n
"
"eor v13.16b, v13.16b, v13.16b
\n
"
"eor v14.16b, v14.16b, v14.16b
\n
"
"eor v15.16b, v15.16b, v15.16b
\n
"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48
\n
"
"eor v16.16b, v16.16b, v16.16b
\n
"
"eor v17.16b, v17.16b, v17.16b
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"eor v18.16b, v18.16b, v18.16b
\n
"
"eor v19.16b, v19.16b, v19.16b
\n
"
"2:
\n
"
"cmp %w[K], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], 16
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v12.4s, v0.4s, v3.s[0]
\n
"
"fmla v13.4s, v0.4s, v3.s[1]
\n
"
"ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48
\n
"
"fmla v14.4s, v0.4s, v3.s[2]
\n
"
"fmla v15.4s, v0.4s, v3.s[3]
\n
"
"fmla v16.4s, v0.4s, v4.s[0]
\n
"
"fmla v17.4s, v0.4s, v4.s[1]
\n
"
"fmla v18.4s, v0.4s, v4.s[2]
\n
"
"fmla v19.4s, v0.4s, v4.s[3]
\n
"
"fmla v8.4s, v1.4s, v5.s[0]
\n
"
"fmla v9.4s, v1.4s, v5.s[1]
\n
"
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], 48
\n
"
"fmla v10.4s, v1.4s, v5.s[2]
\n
"
"fmla v11.4s, v1.4s, v5.s[3]
\n
"
"ld1 {v0.4s}, [%[a_ptr]], 16
\n
"
"fmla v12.4s, v1.4s, v6.s[0]
\n
"
"fmla v13.4s, v1.4s, v6.s[1]
\n
"
"subs %w[K], %w[K], #1
\n
"
"fmla v14.4s, v1.4s, v6.s[2]
\n
"
"fmla v15.4s, v1.4s, v6.s[3]
\n
"
"fmla v16.4s, v1.4s, v7.s[0]
\n
"
"fmla v17.4s, v1.4s, v7.s[1]
\n
"
"fmla v18.4s, v1.4s, v7.s[2]
\n
"
"fmla v19.4s, v1.4s, v7.s[3]
\n
"
"bne 3b
\n
"
"4:
\n
"
"cmp %w[oddk], #1
\n
"
"beq 5f
\n
"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], 16
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"ld1 {v5.4s, v6.4s, v7.4s}, [%[b_ptr]], #48
\n
"
"fmla v12.4s, v0.4s, v3.s[0]
\n
"
"fmla v13.4s, v0.4s, v3.s[1]
\n
"
"fmla v14.4s, v0.4s, v3.s[2]
\n
"
"fmla v15.4s, v0.4s, v3.s[3]
\n
"
"fmla v16.4s, v0.4s, v4.s[0]
\n
"
"fmla v17.4s, v0.4s, v4.s[1]
\n
"
"fmla v18.4s, v0.4s, v4.s[2]
\n
"
"fmla v19.4s, v0.4s, v4.s[3]
\n
"
"fmla v8.4s, v1.4s, v5.s[0]
\n
"
"fmla v9.4s, v1.4s, v5.s[1]
\n
"
"fmla v10.4s, v1.4s, v5.s[2]
\n
"
"fmla v11.4s, v1.4s, v5.s[3]
\n
"
"ld1 {v0.4s}, [%[a_ptr]], 16
\n
"
"fmla v12.4s, v1.4s, v6.s[0]
\n
"
"fmla v13.4s, v1.4s, v6.s[1]
\n
"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64
\n
"
"fmla v14.4s, v1.4s, v6.s[2]
\n
"
"fmla v15.4s, v1.4s, v6.s[3]
\n
"
"fmla v16.4s, v1.4s, v7.s[0]
\n
"
"fmla v17.4s, v1.4s, v7.s[1]
\n
"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64
\n
"
"fmla v18.4s, v1.4s, v7.s[2]
\n
"
"fmla v19.4s, v1.4s, v7.s[3]
\n
"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64
\n
"
"b 6f
\n
"
// odd tail
"5:
\n
"
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v12.4s, v0.4s, v3.s[0]
\n
"
"fmla v13.4s, v0.4s, v3.s[1]
\n
"
"fmla v14.4s, v0.4s, v3.s[2]
\n
"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]], #64
\n
"
"fmla v15.4s, v0.4s, v3.s[3]
\n
"
"fmla v16.4s, v0.4s, v4.s[0]
\n
"
"fmla v17.4s, v0.4s, v4.s[1]
\n
"
"fmla v18.4s, v0.4s, v4.s[2]
\n
"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[output0]], #64
\n
"
"fmla v19.4s, v0.4s, v4.s[3]
\n
"
"st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[output0]], #64
\n
"
"6:
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
oddk
]
"+r"
(
oddk
),
[
output0
]
"+r"
(
output0
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"x1"
,
"cc"
,
"memory"
);
}
// Overview of register layout:
//
// A 2x4 cell of Rhs is stored in 32bit in v2 - v3
// A 4x2 cell of Lhs is stored in 32bit in v0 - v1
// A 4x4 block of accumulators is stored in 32bit in v4-v6
//
// +--------+
// | v2[0-3]|
// | v5[0-3]|
// Rhs +--------+
//
// | |
//
// Lhs | |
//
// +--+ --- - +--------+
// |v0| | v8[0-3]|
// |v0| |v11[0-3]|
// |v0| |v14[0-3]|
// |v0| |v17[0-3]|
// +--+ --- - +--------+
//
// Accumulator
void
kern_4x4
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
,
int
n_remain
)
{
MEGDNN_MARK_USED_VAR
(
LDC
);
const
float
*
a_ptr
=
packA
;
const
float
*
b_ptr
=
packB
;
float
*
output0
=
output
;
int
oddk
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
//clang-format off
#define LOAD_C \
"cmp %w[n_remain], #4\n" \
"blt 11f\n" \
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"b 14f\n" \
"11:\n" \
"cmp %w[n_remain], #3\n" \
"blt 12f\n" \
"ld1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"b 14f\n" \
"12:\n" \
"cmp %w[n_remain], #2\n" \
"blt 13f\n" \
"ld1 {v8.4s, v9.4s}, [%[output0]]\n" \
"b 14f\n" \
"13:\n" \
"ld1 {v8.4s}, [%[output0]]\n" \
"14:\n"
#define STORE_C \
"cmp %w[n_remain], #4\n" \
"blt 21f\n" \
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[output0]]\n" \
"b 24f\n" \
"21:\n" \
"cmp %w[n_remain], #3\n" \
"blt 22f\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"b 23f\n" \
"22:\n" \
"cmp %w[n_remain], #2\n" \
"blt 23f\n" \
"st1 {v8.4s, v9.4s}, [%[output0]]\n" \
"b 24f\n" \
"23:\n" \
"st1 {v8.4s}, [%[output0]]\n" \
"24:\n"
//clang-format on
asm
volatile
(
// load accumulator C
"cmp %w[is_first_k], #1
\n
"
"beq 1f
\n
"
LOAD_C
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"ld1 {v2.4s}, [%[b_ptr]], #16
\n
"
"b 2f
\n
"
"1:
\n
"
"eor v8.16b, v8.16b, v8.16b
\n
"
"ld1 {v2.4s}, [%[b_ptr]], #16
\n
"
"eor v9.16b, v9.16b, v9.16b
\n
"
"ld1 {v0.4s}, [%[a_ptr]], #16
\n
"
"eor v10.16b, v10.16b, v10.16b
\n
"
"prfm pstl1keep, [%[output0]]
\n
"
"eor v11.16b, v11.16b, v11.16b
\n
"
"2:
\n
"
"cmp %w[K], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], 16
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"ld1 {v3.4s}, [%[b_ptr]], 16
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v8.4s, v1.4s, v3.s[0]
\n
"
"fmla v9.4s, v1.4s, v3.s[1]
\n
"
"ld1 {v0.4s}, [%[a_ptr]], 16
\n
"
"fmla v10.4s, v1.4s, v3.s[2]
\n
"
"fmla v11.4s, v1.4s, v3.s[3]
\n
"
"ld1 {v2.4s}, [%[b_ptr]], 16
\n
"
"subs %w[K], %w[K], #1
\n
"
"bne 3b
\n
"
"4:
\n
"
"cmp %w[oddk], #1
\n
"
"beq 5f
\n
"
// Even tail
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"ld1 {v1.4s}, [%[a_ptr]], 16
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"ld1 {v3.4s}, [%[b_ptr]], 16
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"fmla v8.4s, v1.4s, v3.s[0]
\n
"
"fmla v9.4s, v1.4s, v3.s[1]
\n
"
"fmla v10.4s, v1.4s, v3.s[2]
\n
"
"fmla v11.4s, v1.4s, v3.s[3]
\n
"
"b 6f
\n
"
// odd tail
"5:
\n
"
"fmla v8.4s, v0.4s, v2.s[0]
\n
"
"fmla v9.4s, v0.4s, v2.s[1]
\n
"
"fmla v10.4s, v0.4s, v2.s[2]
\n
"
"fmla v11.4s, v0.4s, v2.s[3]
\n
"
"6:
\n
"
STORE_C
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
oddk
]
"+r"
(
oddk
),
[
output0
]
"+r"
(
output0
),
[
n_remain
]
"+r"
(
n_remain
)
:
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"cc"
,
"memory"
);
#undef LOAD_C
#undef STORE_C
}
/**
 * \brief Pack a range of the MK4-format matrix A into the panel layout read by
 *        the 8x12 / 4x12 / 8x4 / 4x4 kernels.
 *
 * Rows [y0, ymax) are consumed eight at a time (two neighbouring 4-row blocks
 * merged by interleave_2x4_4_s); a leftover 4-row block, if any, is copied by
 * interleave_1x4_4_s.
 *
 * \param outptr destination buffer; moved forward by 32 floats per 8-row step
 *               and by 16 floats per 4-row step
 * \param inptr  base address of A
 * \param ldin   distance, in floats, between two consecutive 4-row blocks
 * \param y0     first row to pack (must be a multiple of 4)
 * \param ymax   one past the last row to pack (must be a multiple of 4)
 * \param k0     first K index to pack (must be a multiple of 4)
 * \param kmax   one past the last K index to pack (must be a multiple of 4)
 *
 * \note The source pointers are never incremented in this function, so the
 *       interleave helpers are expected to advance them by reference --
 *       confirm against their declarations.
 */
void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, int y0,
                       int ymax, int k0, int kmax) {
    megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int PACK_SIZE_32 = 4 * 8;
    constexpr int PACK_SIZE_16 = 4 * 4;
    constexpr int PACK_C_SIZE = 4;

    // Every step below covers 4 K indices and emits 16 floats per 4-row
    // block, i.e. one K index occupies PACK_C_SIZE floats inside a row block.
    // The starting column therefore has to be scaled by PACK_C_SIZE (the same
    // way sgemm_8x12_pack_B scales x0); a bare "+ k0" is only right for
    // k0 == 0.
    const int k_offset = k0 * PACK_C_SIZE;

    int y = y0;
    // Two 4-row blocks at a time -> 8-row panels.
    for (; y + 7 < ymax; y += 8) {
        const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k_offset;
        const float* inptr1 = inptr0 + ldin;

        prefetch_2x(inptr0);
        prefetch_2x(inptr1);

        int k = (kmax - k0);
        for (; k > 3; k -= 4) {
            interleave_2x4_4_s(inptr0, inptr1, outptr);
            outptr += PACK_SIZE_32;
        }
    }

    // Remaining single 4-row block -> 4-row panel.
    for (; y < ymax; y += 4) {
        const float* inptr0 = inptr + y / PACK_C_SIZE * ldin + k_offset;

        prefetch_2x(inptr0);

        int K = (kmax - k0);
        for (; K > 3; K -= 4) {
            interleave_1x4_4_s(inptr0, outptr);
            outptr += PACK_SIZE_16;
        }
    }
}
void
sgemm_8x12_pack_B
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
)
{
megdnn_assert
(
k0
%
4
==
0
&&
kmax
%
4
==
0
,
"K must be time of 4"
);
float
tmpbuff
[
16
]
=
{
0.0
f
};
constexpr
int
PACK_C_SIZE
=
4
;
int
ksize
=
kmax
-
k0
;
int
ksize12
=
ksize
*
12
;
int
ksize4
=
(
ksize
<<
2
);
float
*
outptr_base
=
out
;
float
*
outptr_base4
=
outptr_base
+
(
xmax
-
x0
)
/
12
*
ksize12
;
int
k
=
k0
;
for
(;
k
+
3
<
kmax
;
k
+=
4
)
{
const
float
*
inptr
=
in
+
k
/
PACK_C_SIZE
*
ldin
+
x0
*
PACK_C_SIZE
;
prefetch_3x
(
inptr
);
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
12
<=
xmax
;
x
+=
12
)
{
auto
outptr_interleave
=
outptr
;
transpose_1x12_4_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize12
;
}
outptr
=
outptr_base4
;
for
(;
x
+
4
<=
xmax
;
x
+=
4
)
{
auto
outptr_interleave
=
outptr
;
transpose_1x4_4_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize4
;
}
if
(
x
<
xmax
)
{
std
::
memcpy
(
tmpbuff
,
inptr
,
sizeof
(
float
)
*
(
xmax
-
x
)
*
PACK_C_SIZE
);
auto
outptr_interleave
=
outptr
;
const
float
*
tmp_ptr
=
&
tmpbuff
[
0
];
transpose_1x4_4_s
<
float
>
(
tmp_ptr
,
outptr_interleave
);
outptr
+=
ksize4
;
}
outptr_base
+=
12
*
4
;
outptr_base4
+=
4
*
4
;
}
}
}
// namespace matmul_mk4_8x12
}
// aarch64
}
// megdnn
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
浏览文件 @
07d1d0ab
...
...
@@ -12,6 +12,7 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/common/utils.h"
using
namespace
megdnn
;
...
...
@@ -163,4 +164,80 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
}
}
// Strategy boilerplate for sgemm_mk4_8x12, whose pack_A / pack_B / kern
// members are defined right below. Presumably the counterpart of the
// MEGDNN_REG_GEMM_STRATEGY declaration in strategy.h -- TODO confirm against
// the macro definition.
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
sgemm_mk4_8x12
);
/**
 * \brief Pack matrix A for the MK4 8x12 strategy.
 *
 * Asserts that no transposition is requested, then forwards every remaining
 * argument unchanged to matmul_mk4_8x12::sgemm_8x12_pack_A.
 *
 * \param out          destination buffer
 * \param in           source matrix A
 * \param ldin         leading dimension of \p in
 * \param y0, ymax     row range to pack
 * \param k0, kmax     K range to pack
 * \param transpose_A  must be false
 */
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
                            int ymax, int k0, int kmax,
                            bool transpose_A) const {
    megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A");
    matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}
/**
 * \brief Pack matrix B for the MK4 8x12 strategy.
 *
 * Asserts that no transposition is requested, then forwards every remaining
 * argument unchanged to matmul_mk4_8x12::sgemm_8x12_pack_B.
 *
 * \param out          destination buffer
 * \param in           source matrix B
 * \param ldin         leading dimension of \p in
 * \param x0, xmax     column range to pack
 * \param k0, kmax     K range to pack
 * \param transpose_B  must be false
 */
void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
                            int xmax, int k0, int kmax,
                            bool transpose_B) const {
    megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B");
    matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}
/**
 * \brief Run the MK4 fp32 GEMM micro-kernels over already packed A and B.
 *
 * M is walked in steps of 8 rows (kern_8x12 / kern_8x4) and then in steps of
 * 4 rows (kern_4x12 / kern_4x4). For every row step, N is walked in steps of
 * 12 columns and then in steps of 4 columns; the 4-column kernels receive the
 * number of valid columns, min(N - n, 4).
 *
 * \param packA       packed A; advanced by K * 8 floats per 8-row step and by
 *                    K * 4 floats per 4-row step
 * \param packB       packed B; restarted for every row step, advanced by
 *                    K * 12 floats per 12-column step and K * 4 floats per
 *                    4-column step
 * \param M, N, K     problem sizes (M and K must be multiples of 4)
 * \param C           output matrix
 * \param LDC         distance, in floats, between consecutive 4-row blocks of
 *                    \p C
 * \param is_first_k  forwarded to the kernels unchanged
 *
 * The two trailing unnamed pointer parameters are unused.
 */
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M,
                          size_t N, size_t K, float* C, size_t LDC,
                          bool is_first_k, const float*, float*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");

    constexpr size_t PACK_C_SIZE = 4;
    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t A_INTERLEAVE4 = 4;
    constexpr size_t B_INTERLEAVE = 12;

    // Panel strides (in floats) inside the packed buffers.
    const int K12 = K * 12;
    const int K8 = K * 8;
    const int K4 = K * 4;

    size_t m = 0;

    // 8-row panels of A.
    while (m + A_INTERLEAVE <= M) {
        float* output = C + (m / PACK_C_SIZE * LDC);
        const float* cur_packB = packB;
        size_t n = 0;

        while (n + B_INTERLEAVE <= N) {
            matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
                                       is_first_k);
            output += B_INTERLEAVE * PACK_C_SIZE;
            cur_packB += K12;
            n += B_INTERLEAVE;
        }

        while (n < N) {
            matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
                                      is_first_k,
                                      std::min<size_t>(N - n, 4));
            output += 4 * PACK_C_SIZE;
            cur_packB += K4;
            n += 4;
        }

        packA += K8;
        m += A_INTERLEAVE;
    }

    // Remaining 4-row panels of A.
    while (m < M) {
        float* output = C + (m / PACK_C_SIZE * LDC);
        const float* cur_packB = packB;
        size_t n = 0;

        while (n + B_INTERLEAVE - 1 < N) {
            matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
                                       is_first_k);
            output += B_INTERLEAVE * PACK_C_SIZE;
            cur_packB += K12;
            n += B_INTERLEAVE;
        }

        while (n < N) {
            matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC,
                                      is_first_k,
                                      std::min<size_t>(N - n, 4));
            output += 4 * PACK_C_SIZE;
            cur_packB += K4;
            n += 4;
        }

        packA += K4;
        m += A_INTERLEAVE4;
    }
}
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/fp32/strategy.h
浏览文件 @
07d1d0ab
...
...
@@ -20,6 +20,9 @@ MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
// Strategy declaration for sgemm_4x16 (three float type arguments).
// NOTE(review): the numeric arguments 4, 16, 1 look like the kernel tile
// sizes and the two booleans like packing/transpose flags -- confirm against
// the MEGDNN_REG_GEMM_STRATEGY definition.
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
16
,
1
,
false
,
true
,
sgemm_4x16
);
// Strategy declaration for sgemm_mk4_8x12, the MK4 fp32 strategy whose kern()
// tiles the output in 8-row by 12-column steps.
// NOTE(review): unlike its neighbours, both boolean arguments are false here;
// their meaning is not visible in this file -- confirm against the macro.
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
8
,
12
,
1
,
false
,
false
,
sgemm_mk4_8x12
);
// Strategy declaration for sgemm_nopack_4x16, using the _NOPACK variant of
// the registration macro (presumably for kernels that read unpacked operands
// -- TODO confirm).
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
float
,
float
,
float
,
4
,
16
,
1
,
false
,
true
,
sgemm_nopack_4x16
);
...
...
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
07d1d0ab
...
...
@@ -18,6 +18,7 @@ using namespace aarch64;
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoF32K8x12x1
f32K8x12x1
;
AlgoF32MK4_8x12x1
f32_mk4_8x12x1
;
AlgoF32K4x16x1
f32k4x16x1
;
AlgoF32MK4_4x16
f32mk4_4x16
;
AlgoF32Gemv
f32_gemv
;
...
...
@@ -53,6 +54,7 @@ public:
AlgoPack
()
{
all_algos
.
emplace_back
(
&
f32_gemv
);
all_algos
.
emplace_back
(
&
f32K8x12x1
);
all_algos
.
emplace_back
(
&
f32_mk4_8x12x1
);
all_algos
.
emplace_back
(
&
f32k4x16x1
);
all_algos
.
emplace_back
(
&
f32mk4_4x16
);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
07d1d0ab
...
...
@@ -22,6 +22,7 @@ public:
private:
class
AlgoF32K8x12x1
;
// Aarch64 F32 Kernel 8X12X1
class
AlgoF32MK4_8x12x1
;
// Aarch64 F32 Kernel MK4 8x12x1
class
AlgoF32K4x16x1
;
// Aarch64 F32 Kernel 4x16x1
class
AlgoF32MK4_4x16
;
// Aarch64 F32 Format MK4 block 16x4
class
AlgoF32Gemv
;
// Aarch64 F32 Gemv
...
...
dnn/src/arm_common/conv_bias/f16/algos.cpp
浏览文件 @
07d1d0ab
...
...
@@ -244,6 +244,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
if
(
param
.
filter_meta
.
icpg
%
8
!=
0
||
param
.
filter_meta
.
ocpg
%
8
!=
0
)
return
false
;
using
Strategy
=
winograd
::
winograd_2x3_8x8_f16
;
using
PackMode
=
fallback
::
MatrixMulImpl
::
AlgoBase
::
PackMode
;
Strategy
strategy
(
param
.
src_type
,
param
.
filter_type
,
param
.
dst_type
);
auto
&&
matmul_param
=
megdnn
::
winograd
::
ConvBias
<
Strategy
,
...
...
@@ -252,6 +253,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable(
param
.
osz
[
1
],
param
.
filter_meta
.
ocpg
)
.
get_matmul_kern_param
(
param
);
return
m_matmul_algo
->
usable
(
matmul_param
)
&&
m_matmul_algo
->
packmode
()
==
PackMode
::
NO_PACK
&&
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW
||
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
&&
...
...
dnn/src/arm_common/conv_bias/fp32/algos.cpp
浏览文件 @
07d1d0ab
...
...
@@ -38,6 +38,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
if
(
param
.
filter_meta
.
icpg
%
4
!=
0
||
param
.
filter_meta
.
ocpg
%
4
!=
0
)
return
false
;
using
Strategy
=
winograd
::
winograd_2x3_4x4_f
;
using
PackMode
=
fallback
::
MatrixMulImpl
::
AlgoBase
::
PackMode
;
Strategy
strategy
(
param
.
src_type
,
param
.
filter_type
,
param
.
dst_type
);
auto
&&
matmul_param
=
megdnn
::
winograd
::
ConvBias
<
Strategy
,
...
...
@@ -46,6 +47,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
param
.
osz
[
1
],
param
.
filter_meta
.
ocpg
)
.
get_matmul_kern_param
(
param
);
return
m_matmul_algo
->
usable
(
matmul_param
)
&&
m_matmul_algo
->
packmode
()
==
PackMode
::
NO_PACK
&&
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW
||
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
&&
...
...
@@ -319,6 +321,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
if
(
param
.
filter_meta
.
icpg
%
4
!=
0
||
param
.
filter_meta
.
ocpg
%
4
!=
0
)
return
false
;
using
Strategy
=
winograd
::
winograd_6x3_4x4_f
;
using
PackMode
=
fallback
::
MatrixMulImpl
::
AlgoBase
::
PackMode
;
Strategy
strategy
(
param
.
src_type
,
param
.
filter_type
,
param
.
dst_type
);
auto
&&
matmul_param
=
megdnn
::
winograd
::
ConvBias
<
Strategy
,
...
...
@@ -327,6 +330,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
param
.
osz
[
1
],
param
.
filter_meta
.
ocpg
)
.
get_matmul_kern_param
(
param
);
return
m_matmul_algo
->
usable
(
matmul_param
)
&&
m_matmul_algo
->
packmode
()
==
PackMode
::
NO_PACK
&&
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW
||
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
&&
...
...
dnn/src/arm_common/conv_bias/int8/algos.cpp
浏览文件 @
07d1d0ab
...
...
@@ -217,6 +217,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
if
(
param
.
filter_meta
.
icpg
%
8
!=
0
||
param
.
filter_meta
.
ocpg
%
8
!=
0
)
return
false
;
using
Strategy
=
winograd
::
winograd_2x3_8x8_s8
;
using
PackMode
=
fallback
::
MatrixMulImpl
::
AlgoBase
::
PackMode
;
Strategy
strategy
(
param
.
src_type
,
param
.
filter_type
,
param
.
dst_type
);
auto
&&
matmul_param
=
megdnn
::
winograd
::
ConvBias
<
Strategy
,
param
::
MatrixMul
::
Format
::
MK8
>
(
...
...
@@ -224,6 +225,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
param
.
osz
[
1
],
param
.
filter_meta
.
ocpg
)
.
get_matmul_kern_param
(
param
);
return
m_matmul_algo
->
usable
(
matmul_param
)
&&
m_matmul_algo
->
packmode
()
==
PackMode
::
NO_PACK
&&
((
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW
&&
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
||
(
opr
->
param
().
format
==
param
::
ConvBias
::
Format
::
NCHW_WINOGRAD
&&
...
...
dnn/test/aarch64/matrix_mul.cpp
浏览文件 @
07d1d0ab
...
...
@@ -31,6 +31,12 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) {
"AARCH64_F32K4X16X1"
);
}
// Correctness test for the packed MK4 fp32 matmul algorithm: runs the shared
// matrix-mul checker with all-Float32 dtypes, the algorithm name
// "AARCH64_F32_MK4_K8X12X1", the MK4 format and a trailing argument of 1
// (presumably nbase, as in the neighbouring MK4 test -- TODO confirm).
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
}
TEST_F
(
AARCH64
,
MATRIX_MUL_FP32_MK4
)
{
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul
::
check_matrix_mul
(
...
...
@@ -527,6 +533,15 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_MK4) {
dtype
::
Float32
{});
}
// Benchmark: packed MK4 fp32 matmul ("AARCH64_F32_MK4_K8X12X1", MK4 format)
// contrasted with "AARCH64_F32K8X12X1", both with all-Float32 dtypes, over
// the argument set returned by get_benchmark_matmul_mk_packed_args(16).
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_PACK_MK4) {
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(16);
    matrix_mul::benchmark_with_contrast(
            handle(), args,
            // algorithm under test
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4,
            // contrast algorithm
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
            "AARCH64_F32K8X12X1");
}
TEST_F
(
AARCH64
,
BENCHMARK_MATRIX_MUL_INT16x16x32_MK8
)
{
auto
args
=
matrix_mul
::
get_benchmark_matmul_mk_packed_args
(
8
);
matrix_mul
::
benchmark_with_contrast
(
...
...
dnn/test/common/matrix_mul.cpp
浏览文件 @
07d1d0ab
...
...
@@ -40,8 +40,8 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(
size_t
nbase
)
{
std
::
vector
<
TestArg
>
args
;
for
(
size_t
m
:
{
1
,
2
,
3
,
4
,
5
})
for
(
size_t
n
:
{
1
,
2
,
3
,
4
,
5
,
8
,
16
,
24
})
for
(
size_t
k
:
{
1
,
2
,
3
,
4
,
5
})
for
(
size_t
n
:
{
1
,
2
,
3
,
4
,
5
,
8
,
1
2
,
1
6
,
24
})
for
(
size_t
k
:
{
1
,
2
,
3
,
4
,
5
,
9
,
10
})
args
.
emplace_back
(
m
,
n
*
nbase
,
k
,
0
);
return
args
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录