Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
34659c2e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
34659c2e
编写于
6月 12, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/dnn): remove armv7 matmul mk4dot block 8x6
GitOrigin-RevId: 4c746ef22895ebc6ab4298c66f71e239b437cc69
上级
48ac1e1a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
42 addition
and
405 deletion
+42
-405
dnn/src/armv7/matrix_mul/algos.cpp
dnn/src/armv7/matrix_mul/algos.cpp
+15
-15
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+2
-2
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
+6
-352
dnn/src/armv7/matrix_mul/int8/strategy.cpp
dnn/src/armv7/matrix_mul/int8/strategy.cpp
+12
-29
dnn/src/armv7/matrix_mul/int8/strategy.h
dnn/src/armv7/matrix_mul/int8/strategy.h
+2
-2
dnn/src/armv7/matrix_mul/opr_impl.cpp
dnn/src/armv7/matrix_mul/opr_impl.cpp
+2
-2
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+1
-1
dnn/test/armv7/matrix_mul.cpp
dnn/test/armv7/matrix_mul.cpp
+2
-2
未找到文件。
dnn/src/armv7/matrix_mul/algos.cpp
浏览文件 @
34659c2e
...
...
@@ -707,11 +707,11 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
armv7
::
matmul
::
gemm_dot_quint8_4x8
,
uint8_t
,
int32_t
);
/* ======================== Int8 MK4 8x
6
x4 dot algo ======================== */
/* ======================== Int8 MK4 8x
4
x4 dot algo ======================== */
namespace
{
void
int8_mk4_8x
6
x4_dotprod_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
void
int8_mk4_8x
4
x4_dotprod_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
MIDOUT_BEGIN
(
megdnn_armv7_matmul_kern
,
midout_iv
(
"int8_mk4_8x
6
x4_dotprod_kern"
_hash
))
{
midout_iv
(
"int8_mk4_8x
4
x4_dotprod_kern"
_hash
))
{
auto
M
=
kern_param
.
M
,
N
=
kern_param
.
N
,
K
=
kern_param
.
K
;
auto
trA
=
kern_param
.
trA
,
trB
=
kern_param
.
trB
;
auto
LDA
=
kern_param
.
LDA
,
LDB
=
kern_param
.
LDB
,
LDC
=
kern_param
.
LDC
;
...
...
@@ -720,9 +720,9 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
const
auto
Aptr
=
kern_param
.
A
<
dt_int8
>
(),
Bptr
=
kern_param
.
B
<
dt_int8
>
();
auto
Cptr
=
kern_param
.
C
<
dt_int32
>
();
armv7
::
matmul
::
gemm_mk4_dots8_8x
6
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
armv7
::
matmul
::
gemm_mk4_dots8_8x
4
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
gemm_mk4_dots8_8x
6
>
(
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
gemm_mk4_dots8_8x
4
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
...
...
@@ -731,7 +731,7 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
}
}
// namespace
bool
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
6
x4DotProd
::
usable
(
bool
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
4
x4DotProd
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
kern_size_param
.
A_type
.
enumv
()
==
kern_size_param
.
B_type
.
enumv
()
&&
(
kern_size_param
.
A_type
.
enumv
()
==
DTypeEnum
::
Int8
||
...
...
@@ -743,35 +743,35 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable(
!
kern_size_param
.
trA
&&
!
kern_size_param
.
trB
;
}
size_t
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
6
x4DotProd
::
get_workspace
(
size_t
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
4
x4DotProd
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_armv7_matmul_kern
,
midout_iv
(
"AlgoInt8x8x32MK4_8x
6
x4DotProd::get_workspace"
_hash
))
{
midout_iv
(
"AlgoInt8x8x32MK4_8x
4
x4DotProd::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
armv7
::
matmul
::
gemm_mk4_dots8_8x
6
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
armv7
::
matmul
::
gemm_mk4_dots8_8x
4
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
gemm_mk4_dots8_8x
6
>
(
M
,
N
,
K
,
trA
,
trB
,
armv7
::
matmul
::
gemm_mk4_dots8_8x
4
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
6
x4DotProd
::
get_kern
(
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
4
x4DotProd
::
get_kern
(
const
KernSizeParam
&
)
const
{
return
int8_mk4_8x
6
x4_dotprod_kern
;
return
int8_mk4_8x
4
x4_dotprod_kern
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoInt8x8x32MK4_8x
6
x4DotProd
,
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoInt8x8x32MK4_8x
4
x4DotProd
,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x32MK4_8x
6
x4DotProd"
_hash
,
armv7
::
matmul
::
gemm_mk4_dots8_8x
6
,
int8_t
,
"AlgoInt8x8x32MK4_8x
4
x4DotProd"
_hash
,
armv7
::
matmul
::
gemm_mk4_dots8_8x
4
,
int8_t
,
int32_t
);
#endif
...
...
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
34659c2e
...
...
@@ -94,11 +94,11 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
6
x4DotProd
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x
4
x4DotProd
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"AARCH32_INT8_MK4_8X
6
X4_DOTPROD"
;
return
"AARCH32_INT8_MK4_8X
4
X4_DOTPROD"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
...
...
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x
6
x4.h
→
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x
4
x4.h
浏览文件 @
34659c2e
/**
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x
6
x4.h
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x
4
x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -17,205 +17,7 @@
namespace
megdnn
{
namespace
armv7
{
namespace
matmul_mk4_dot_8x6x4
{
// Overview of register layout:
//
// A 1x6x4 cell of Rhs is stored in 8bit in q0, q1.
// A 2x1x4x4 cell of Lhs is stored in 8bit in q2, q3
// A 2x6x4 block of accumulators is stored in 8bit in q4-q15
//
// +--------+
// Rhs |q0[0-16]|
// |q1[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q2[0-16]| | q4[0-4]|
// | q3[0-16]| | q5[0-4]|
// +---------+ | q6[0-4]|
// | q7[0-4]|
// | q8[0-4]|
// | q9[0-4]|
// |q10[0-4]|
// |q11[0-4]|
// |q12[0-4]|
// |q13[0-4]|
// |q14[0-4]|
// |q15[0-4]|
// +--------+
// Accumulator
static
void
kern_8x6
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
)
{
K
/=
4
;
const
int8_t
*
a_ptr
=
packA
;
const
int8_t
*
b_ptr
=
packB
;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int
oddk
=
(
K
&
1
);
int
k
=
(
K
+
1
)
/
2
-
1
;
LDC
=
LDC
*
sizeof
(
int32_t
);
int32_t
*
outptr0
=
output
;
int32_t
*
outptr1
;
asm
volatile
(
// load accumulator C
"add %[outptr1], %[outptr0], %[LDC]
\n
"
"cmp %[is_first_k], #1
\n
"
"beq 1f
\n
"
"vld1.32 {d8, d9}, [%[outptr0]]!
\n
"
"vld1.32 {d10, d11}, [%[outptr0]]!
\n
"
"vld1.32 {d12, d13}, [%[outptr0]]!
\n
"
"vld1.32 {d14, d15}, [%[outptr0]]!
\n
"
"vld1.32 {d16, d17}, [%[outptr0]]!
\n
"
"vld1.32 {d18, d19}, [%[outptr0]]!
\n
"
"vld1.32 {d20, d21}, [%[outptr1]]!
\n
"
"vld1.32 {d22, d23}, [%[outptr1]]!
\n
"
"vld1.32 {d24, d25}, [%[outptr1]]!
\n
"
"vld1.32 {d26, d27}, [%[outptr1]]!
\n
"
"vld1.32 {d28, d29}, [%[outptr1]]!
\n
"
"vld1.32 {d30, d31}, [%[outptr1]]!
\n
"
"b 2f
\n
"
"1:
\n
"
"veor.s32 q4, q4, q4
\n
"
"veor.s32 q5, q5, q5
\n
"
"veor.s32 q6, q6, q6
\n
"
"veor.s32 q7, q7, q7
\n
"
"veor.s32 q8, q8, q8
\n
"
"veor.s32 q9, q9, q9
\n
"
"veor.s32 q10, q10, q10
\n
"
"veor.s32 q11, q11, q11
\n
"
"veor.s32 q12, q12, q12
\n
"
"veor.s32 q13, q13, q13
\n
"
"veor.s32 q14, q14, q14
\n
"
"veor.s32 q15, q15, q15
\n
"
"2:
\n
"
"vld1.s8 {q0}, [%[b_ptr]]!
\n
"
"vld1.s8 {d2}, [%[b_ptr]]!
\n
"
"vld1.s8 {q2}, [%[a_ptr]]!
\n
"
"vld1.s8 {q3}, [%[a_ptr]]!
\n
"
"cmp %[k], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"vsdot.s8 q4 , q2, d0[0]
\n
"
"vsdot.s8 q5 , q2, d0[1]
\n
"
"vsdot.s8 q6 , q2, d1[0]
\n
"
"vsdot.s8 q7 , q2, d1[1]
\n
"
"vsdot.s8 q8 , q2, d2[0]
\n
"
"vsdot.s8 q9 , q2, d2[1]
\n
"
"vsdot.s8 q10 , q3, d0[0]
\n
"
"vsdot.s8 q11 , q3, d0[1]
\n
"
"vsdot.s8 q12 , q3, d1[0]
\n
"
"vsdot.s8 q13 , q3, d1[1]
\n
"
"vsdot.s8 q14 , q3, d2[0]
\n
"
"vsdot.s8 q15 , q3, d2[1]
\n
"
"vld1.s8 {q0}, [%[b_ptr]]!
\n
"
"vld1.s8 {d2}, [%[b_ptr]]!
\n
"
"vld1.s8 {q2}, [%[a_ptr]]!
\n
"
"vld1.s8 {q3}, [%[a_ptr]]!
\n
"
"vsdot.s8 q4 , q2, d0[0]
\n
"
"vsdot.s8 q5 , q2, d0[1]
\n
"
"vsdot.s8 q6 , q2, d1[0]
\n
"
"vsdot.s8 q7 , q2, d1[1]
\n
"
"vsdot.s8 q8 , q2, d2[0]
\n
"
"vsdot.s8 q9 , q2, d2[1]
\n
"
"vsdot.s8 q10 , q3, d0[0]
\n
"
"vsdot.s8 q11 , q3, d0[1]
\n
"
"vsdot.s8 q12 , q3, d1[0]
\n
"
"vsdot.s8 q13 , q3, d1[1]
\n
"
"vsdot.s8 q14 , q3, d2[0]
\n
"
"vsdot.s8 q15 , q3, d2[1]
\n
"
"vld1.s8 {q0}, [%[b_ptr]]!
\n
"
"vld1.s8 {d2}, [%[b_ptr]]!
\n
"
"vld1.s8 {q2}, [%[a_ptr]]!
\n
"
"vld1.s8 {q3}, [%[a_ptr]]!
\n
"
"subs %[k], %[k], #1
\n
"
"bne 3b
\n
"
// Target to use when K is 1 or 2 (i.e. zero iterations of main
// loop)
"4:
\n
"
"cmp %[oddk], #0
\n
"
"bne 5f
\n
"
"vsdot.s8 q4 , q2, d0[0]
\n
"
"vsdot.s8 q5 , q2, d0[1]
\n
"
"vsdot.s8 q6 , q2, d1[0]
\n
"
"vsdot.s8 q7 , q2, d1[1]
\n
"
"vsdot.s8 q8 , q2, d2[0]
\n
"
"vsdot.s8 q9 , q2, d2[1]
\n
"
"vsdot.s8 q10 , q3, d0[0]
\n
"
"vsdot.s8 q11 , q3, d0[1]
\n
"
"vsdot.s8 q12 , q3, d1[0]
\n
"
"vsdot.s8 q13 , q3, d1[1]
\n
"
"vsdot.s8 q14 , q3, d2[0]
\n
"
"vsdot.s8 q15 , q3, d2[1]
\n
"
"vld1.s8 {q0}, [%[b_ptr]]!
\n
"
"vld1.s8 {d2}, [%[b_ptr]]!
\n
"
"vld1.s8 {q2}, [%[a_ptr]]!
\n
"
"vld1.s8 {q3}, [%[a_ptr]]!
\n
"
"vsdot.s8 q4 , q2, d0[0]
\n
"
"vsdot.s8 q5 , q2, d0[1]
\n
"
"vsdot.s8 q6 , q2, d1[0]
\n
"
"vst1.32 {d8, d9}, [%[outptr0]]!
\n
"
"vsdot.s8 q7 , q2, d1[1]
\n
"
"vsdot.s8 q8 , q2, d2[0]
\n
"
"vsdot.s8 q9 , q2, d2[1]
\n
"
"vst1.32 {d10, d11}, [%[outptr0]]!
\n
"
"vsdot.s8 q10 , q3, d0[0]
\n
"
"vsdot.s8 q11 , q3, d0[1]
\n
"
"vsdot.s8 q12 , q3, d1[0]
\n
"
"vst1.32 {d12, d13}, [%[outptr0]]!
\n
"
"vsdot.s8 q13 , q3, d1[1]
\n
"
"vsdot.s8 q14 , q3, d2[0]
\n
"
"vsdot.s8 q15 , q3, d2[1]
\n
"
"b 6f
\n
"
"5:
\n
"
"vsdot.s8 q4 , q2, d0[0]
\n
"
"vsdot.s8 q5 , q2, d0[1]
\n
"
"vsdot.s8 q6 , q2, d1[0]
\n
"
"vst1.32 {d8, d9}, [%[outptr0]]!
\n
"
"vsdot.s8 q7 , q2, d1[1]
\n
"
"vsdot.s8 q8 , q2, d2[0]
\n
"
"vsdot.s8 q9 , q2, d2[1]
\n
"
"vst1.32 {d10, d11}, [%[outptr0]]!
\n
"
"vsdot.s8 q10 , q3, d0[0]
\n
"
"vsdot.s8 q11 , q3, d0[1]
\n
"
"vsdot.s8 q12 , q3, d1[0]
\n
"
"vst1.32 {d12, d13}, [%[outptr0]]!
\n
"
"vsdot.s8 q13 , q3, d1[1]
\n
"
"vsdot.s8 q14 , q3, d2[0]
\n
"
"vsdot.s8 q15 , q3, d2[1]
\n
"
"6:
\n
"
"vst1.32 {d14, d15}, [%[outptr0]]!
\n
"
"vst1.32 {d16, d17}, [%[outptr0]]!
\n
"
"vst1.32 {d18, d19}, [%[outptr0]]!
\n
"
"vst1.32 {d20, d21}, [%[outptr1]]!
\n
"
"vst1.32 {d22, d23}, [%[outptr1]]!
\n
"
"vst1.32 {d24, d25}, [%[outptr1]]!
\n
"
"vst1.32 {d26, d27}, [%[outptr1]]!
\n
"
"vst1.32 {d28, d29}, [%[outptr1]]!
\n
"
"vst1.32 {d30, d31}, [%[outptr1]]!
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
LDC
]
"+r"
(
LDC
),
[
oddk
]
"+r"
(
oddk
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
k
]
"+r"
(
k
),
[
outptr0
]
"+r"
(
outptr0
),
[
outptr1
]
"+r"
(
outptr1
)
:
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
}
namespace
matmul_mk4_dot_8x4x4
{
// Overview of register layout:
//
...
...
@@ -390,144 +192,6 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}
// Overview of register layout:
//
// A 1x6x4 pingpong cell of Rhs is stored in 8bit in q0-q3.
// A 1x1x4x4 pingpong cell of Lhs is stored in 8bit in q4-q5
// A 2x6x4 block of accumulators is stored in 8bit in q10-q15
//
// +--------+
// Rhs |q0[0-16]|
// |q1[0-16]|
// +--------+
// Lhs | |
// +-------+-------+ - - - - +--------+
// | q4[0-16]| |q10[0-4]|
// | q5[0-16]| |q11[0-4]|
// +---------+ |q12[0-4]|
// |q13[0-4]|
// |q14[0-4]|
// |q15[0-4]|
// +--------+
// Accumulator
static
void
kern_4x6
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int32_t
*
output
,
int
LDC
,
bool
is_first_k
)
{
K
/=
4
;
const
int8_t
*
a_ptr
=
packA
;
const
int8_t
*
b_ptr
=
packB
;
// Fix up for odd lengths - set a flag if K is odd, but make
// sure we round up the iteration count.
int
oddk
=
(
K
&
1
);
int
k
=
(
K
+
1
)
/
2
-
1
;
LDC
=
LDC
*
sizeof
(
int32_t
);
int32_t
*
outptr0
=
output
;
asm
volatile
(
// load accumulator C
"cmp %[is_first_k], #1
\n
"
"beq 1f
\n
"
"vld1.32 {d20, d21}, [%[outptr0]]!
\n
"
"vld1.32 {d22, d23}, [%[outptr0]]!
\n
"
"vld1.32 {d24, d25}, [%[outptr0]]!
\n
"
"vld1.32 {d26, d27}, [%[outptr0]]!
\n
"
"vld1.32 {d28, d29}, [%[outptr0]]!
\n
"
"vld1.32 {d30, d31}, [%[outptr0]]!
\n
"
"b 2f
\n
"
"1:
\n
"
"veor.s32 q10, q10, q10
\n
"
"veor.s32 q11, q11, q11
\n
"
"veor.s32 q12, q12, q12
\n
"
"veor.s32 q13, q13, q13
\n
"
"veor.s32 q14, q14, q14
\n
"
"veor.s32 q15, q15, q15
\n
"
"2:
\n
"
"vld1.s8 {q0}, [%[b_ptr]]!
\n
"
"vld1.s8 {d2}, [%[b_ptr]]!
\n
"
"vld1.s8 {q4}, [%[a_ptr]]!
\n
"
"cmp %[k], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"vsdot.s8 q10 , q4, d0[0]
\n
"
"vsdot.s8 q11 , q4, d0[1]
\n
"
"vsdot.s8 q12 , q4, d1[0]
\n
"
"vld1.s8 {q2}, [%[b_ptr]]!
\n
"
"vld1.s8 {d6}, [%[b_ptr]]!
\n
"
"vld1.s8 {q5}, [%[a_ptr]]!
\n
"
"vsdot.s8 q13 , q4, d1[1]
\n
"
"vsdot.s8 q14 , q4, d2[0]
\n
"
"vsdot.s8 q15 , q4, d2[1]
\n
"
"vld1.s8 {q0}, [%[b_ptr]]!
\n
"
"vsdot.s8 q10 , q5, d4[0]
\n
"
"vsdot.s8 q11 , q5, d4[1]
\n
"
"vsdot.s8 q12 , q5, d5[0]
\n
"
"vld1.s8 {d2}, [%[b_ptr]]!
\n
"
"vsdot.s8 q13 , q5, d5[1]
\n
"
"vsdot.s8 q14 , q5, d6[0]
\n
"
"vsdot.s8 q15 , q5, d6[1]
\n
"
"vld1.s8 {q4}, [%[a_ptr]]!
\n
"
"subs %[k], %[k], #1
\n
"
"bne 3b
\n
"
// Target to use when K is 1 or 2 (i.e. zero iterations of main
// loop)
"4:
\n
"
"cmp %[oddk], #0
\n
"
"bne 5f
\n
"
"vsdot.s8 q10 , q4, d0[0]
\n
"
"vsdot.s8 q11 , q4, d0[1]
\n
"
"vsdot.s8 q12 , q4, d1[0]
\n
"
"vld1.s8 {q2}, [%[b_ptr]]!
\n
"
"vld1.s8 {d6}, [%[b_ptr]]!
\n
"
"vld1.s8 {q5}, [%[a_ptr]]!
\n
"
"vsdot.s8 q13 , q4, d1[1]
\n
"
"vsdot.s8 q14 , q4, d2[0]
\n
"
"vsdot.s8 q15 , q4, d2[1]
\n
"
"vsdot.s8 q10 , q5, d4[0]
\n
"
"vsdot.s8 q11 , q5, d4[1]
\n
"
"vsdot.s8 q12 , q5, d5[0]
\n
"
"vst1.32 {d20, d21}, [%[outptr0]]!
\n
"
"vsdot.s8 q13 , q5, d5[1]
\n
"
"vsdot.s8 q14 , q5, d6[0]
\n
"
"vsdot.s8 q15 , q5, d6[1]
\n
"
"vst1.32 {d22, d23}, [%[outptr0]]!
\n
"
"b 6f
\n
"
"5:
\n
"
"vsdot.s8 q10 , q4, d0[0]
\n
"
"vsdot.s8 q11 , q4, d0[1]
\n
"
"vsdot.s8 q12 , q4, d1[0]
\n
"
"vst1.32 {d20, d21}, [%[outptr0]]!
\n
"
"vsdot.s8 q13 , q4, d1[1]
\n
"
"vsdot.s8 q14 , q4, d2[0]
\n
"
"vsdot.s8 q15 , q4, d2[1]
\n
"
"vst1.32 {d22, d23}, [%[outptr0]]!
\n
"
"6:
\n
"
"vst1.32 {d24, d25}, [%[outptr0]]!
\n
"
"vst1.32 {d26, d27}, [%[outptr0]]!
\n
"
"vst1.32 {d28, d29}, [%[outptr0]]!
\n
"
"vst1.32 {d30, d31}, [%[outptr0]]!
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
LDC
]
"+r"
(
LDC
),
[
oddk
]
"+r"
(
oddk
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
k
]
"+r"
(
k
),
[
outptr0
]
"+r"
(
outptr0
)
:
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q14"
,
"q15"
,
"cc"
,
"memory"
);
}
// Overview of register layout:
//
// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3.
...
...
@@ -671,7 +335,7 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}
static
void
gemm_dots8_8x
6
_pack_A
(
dt_int8
*
outptr
,
const
dt_int8
*
inptr
,
static
void
gemm_dots8_8x
4
_pack_A
(
dt_int8
*
outptr
,
const
dt_int8
*
inptr
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
)
{
int
y
=
y0
,
y_start
=
y0
/
4
;
...
...
@@ -692,14 +356,12 @@ static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr,
}
}
static
void
gemm_dots8_8x
6
_pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
static
void
gemm_dots8_8x
4
_pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
)
{
const
int
ksize
=
kmax
-
k0
;
const
int
ksize4
=
ksize
*
4
;
const
int
ksize6
=
ksize
*
6
;
int8_t
*
outptr
=
out
;
int8_t
*
outptr_base
=
out
;
int8_t
*
outptr_base4
=
out
+
((
xmax
-
x0
)
/
6
)
*
ksize6
;
int
k
=
k0
;
for
(;
k
+
3
<
kmax
;
k
+=
4
)
{
...
...
@@ -708,13 +370,6 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
outptr
=
outptr_base
;
int
x
=
x0
;
for
(;
x
+
5
<
xmax
;
x
+=
6
)
{
memcpy
(
outptr
,
inptr
,
sizeof
(
int8_t
)
*
24
);
outptr
+=
ksize6
;
inptr
+=
24
;
}
outptr
=
outptr_base4
;
for
(;
x
+
3
<
xmax
;
x
+=
4
)
{
memcpy
(
outptr
,
inptr
,
sizeof
(
int8_t
)
*
16
);
outptr
+=
ksize4
;
...
...
@@ -735,12 +390,11 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
*
outptr
++
=
0
;
}
}
outptr_base
+=
24
;
outptr_base4
+=
16
;
outptr_base
+=
16
;
}
}
}
// namespace matmul_mk4_dot_8x
6
x4
}
// namespace matmul_mk4_dot_8x
4
x4
}
// namespace armv7
}
// namespace megdnn
#endif
...
...
dnn/src/armv7/matrix_mul/int8/strategy.cpp
浏览文件 @
34659c2e
...
...
@@ -16,7 +16,7 @@
#include "src/armv7/matrix_mul/int8/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8/kernel_6x8x4.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x
6
x4.h"
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x
4
x4.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
...
...
@@ -254,10 +254,10 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
}
}
// ===========================gemm_mk4_dots8_8x
6
======================================
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_mk4_dots8_8x
6
);
// ===========================gemm_mk4_dots8_8x
4
======================================
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_mk4_dots8_8x
4
);
void
gemm_mk4_dots8_8x
6
::
pack_A
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
void
gemm_mk4_dots8_8x
4
::
pack_A
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
megdnn_assert
(
!
transpose
,
...
...
@@ -266,49 +266,39 @@ void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin,
"mk4 format matmul with m is not times of 4."
);
megdnn_assert
(
kmax
%
4
==
0
&&
k0
%
4
==
0
,
"mk4 format matmul with k is not times of 4."
);
matmul_mk4_dot_8x
6x4
::
gemm_dots8_8x6
_pack_A
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
matmul_mk4_dot_8x
4x4
::
gemm_dots8_8x4
_pack_A
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
);
}
void
gemm_mk4_dots8_8x
6
::
pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
void
gemm_mk4_dots8_8x
4
::
pack_B
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose
)
const
{
megdnn_assert
(
!
transpose
,
"matrix mul mk4 with transposed matrix B is not supported"
);
megdnn_assert
(
kmax
%
4
==
0
&&
k0
%
4
==
0
,
"mk4 format matmul with k is not times of 4."
);
matmul_mk4_dot_8x
6x4
::
gemm_dots8_8x6
_pack_B
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
matmul_mk4_dot_8x
4x4
::
gemm_dots8_8x4
_pack_B
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
);
}
void
gemm_mk4_dots8_8x
6
::
kern
(
const
dt_int8
*
packA
,
const
dt_int8
*
packB
,
void
gemm_mk4_dots8_8x
4
::
kern
(
const
dt_int8
*
packA
,
const
dt_int8
*
packB
,
size_t
M
,
size_t
N
,
size_t
K
,
dt_int32
*
C
,
size_t
LDC
,
bool
is_first_k
,
const
dt_int32
*
bias
,
dt_int32
*
workspace
)
const
{
MEGDNN_MARK_USED_VAR
(
bias
);
constexpr
size_t
A_INTERLEAVE
=
8
;
constexpr
size_t
B_INTERLEAVE
=
6
;
//! K is packed to times of 4
K
=
round_up
<
size_t
>
(
K
,
4
);
const
int
K4
=
K
*
4
;
const
int
K6
=
K
*
6
;
const
int
K8
=
K
*
8
;
size_t
m
=
0
;
for
(;
m
+
A_INTERLEAVE
-
1
<
M
;
m
+=
A_INTERLEAVE
)
{
int32_t
*
output
=
C
+
((
m
>>
2
)
*
LDC
);
const
dt_int8
*
cur_packB
=
packB
;
size_t
n
=
0
;
for
(;
n
+
B_INTERLEAVE
-
1
<
N
;
n
+=
B_INTERLEAVE
)
{
matmul_mk4_dot_8x6x4
::
kern_8x6
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
is_first_k
);
output
+=
24
;
cur_packB
+=
K6
;
}
for
(;
n
<
N
;
n
+=
4
)
{
for
(
size_t
n
=
0
;
n
<
N
;
n
+=
4
)
{
size_t
n_remain
=
std
::
min
<
size_t
>
(
N
-
n
,
4
);
matmul_mk4_dot_8x
6
x4
::
kern_8x4
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
matmul_mk4_dot_8x
4
x4
::
kern_8x4
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
is_first_k
,
n_remain
);
output
+=
16
;
cur_packB
+=
K4
;
...
...
@@ -318,16 +308,9 @@ void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB,
for
(;
m
<
M
;
m
+=
4
)
{
int32_t
*
output
=
C
+
((
m
>>
2
)
*
LDC
);
const
dt_int8
*
cur_packB
=
packB
;
size_t
n
=
0
;
for
(;
n
+
B_INTERLEAVE
-
1
<
N
;
n
+=
B_INTERLEAVE
)
{
matmul_mk4_dot_8x6x4
::
kern_4x6
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
is_first_k
);
output
+=
24
;
cur_packB
+=
K6
;
}
for
(;
n
<
N
;
n
+=
4
)
{
for
(
size_t
n
=
0
;
n
<
N
;
n
+=
4
)
{
size_t
n_remain
=
std
::
min
<
size_t
>
(
N
-
n
,
4
);
matmul_mk4_dot_8x
6
x4
::
kern_4x4
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
matmul_mk4_dot_8x
4
x4
::
kern_4x4
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
is_first_k
,
n_remain
);
output
+=
16
;
cur_packB
+=
K4
;
...
...
dnn/src/armv7/matrix_mul/int8/strategy.h
浏览文件 @
34659c2e
...
...
@@ -27,8 +27,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
MEGDNN_REG_GEMM_STRATEGY
(
dt_int8
,
dt_int32
,
dt_int32
,
6
,
8
,
4
,
false
,
false
,
gemm_dots8_6x8
);
MEGDNN_REG_GEMM_STRATEGY
(
dt_int8
,
dt_int32
,
dt_int32
,
8
,
6
,
4
,
false
,
false
,
gemm_mk4_dots8_8x
6
);
MEGDNN_REG_GEMM_STRATEGY
(
dt_int8
,
dt_int32
,
dt_int32
,
8
,
4
,
4
,
false
,
false
,
gemm_mk4_dots8_8x
4
);
#endif
}
// namespace matmul
}
// namespace armv7
...
...
dnn/src/armv7/matrix_mul/opr_impl.cpp
浏览文件 @
34659c2e
...
...
@@ -29,7 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K6x8x4
int8_k6x8x4
;
AlgoQuint8DotK4x8x4
quint8_k4x8x4
;
AlgoInt8x8x32MK4_8x
6x4DotProd
int8x8x32_mk4_8x6
x4_dotprod
;
AlgoInt8x8x32MK4_8x
4x4DotProd
int8x8x32_mk4_8x4
x4_dotprod
;
#endif
AlgoF32Gemv
f32_gemv
;
AlgoInt8x8x32MK4_4x2x16
int8x8x32_mk4_4x2x16
;
...
...
@@ -57,7 +57,7 @@ public:
all_algos
.
emplace_back
(
&
f16_mk8_4x8
);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x
6
x4_dotprod
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x
4
x4_dotprod
);
all_algos
.
emplace_back
(
&
int8_k6x8x4
);
all_algos
.
emplace_back
(
&
quint8_k4x8x4
);
#endif
...
...
dnn/src/armv7/matrix_mul/opr_impl.h
浏览文件 @
34659c2e
...
...
@@ -42,7 +42,7 @@ private:
#if __ARM_FEATURE_DOTPROD
class
AlgoInt8x8x32K6x8x4
;
// Armv7 Int8 Kernel 6x8x4
class
AlgoQuint8DotK4x8x4
;
// Armv7 Quint8 Kernel 6x8x4
class
AlgoInt8x8x32MK4_8x
6x4DotProd
;
// Armv7 nchw44 Int8x8x32 Kernel 8x6
x4
class
AlgoInt8x8x32MK4_8x
4x4DotProd
;
// Armv7 nchw44 Int8x8x32 Kernel 8x4
x4
// DotProduct
#endif
class
AlgoPack
;
...
...
dnn/test/armv7/matrix_mul.cpp
浏览文件 @
34659c2e
...
...
@@ -94,7 +94,7 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
for
(
size_t
k
:
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
16
,
32
,
33
,
34
})
args
.
emplace_back
(
m
,
n
,
k
,
0
);
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int32
{},
handle
(),
"AARCH32_INT8_MK4_8X
6
X4_DOTPROD"
,
handle
(),
"AARCH32_INT8_MK4_8X
4
X4_DOTPROD"
,
param
::
MatrixMul
::
Format
::
MK4_DOT
,
1
,
1e-3
,
std
::
move
(
args
));
}
...
...
@@ -315,7 +315,7 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) {
param
.
format
=
MatrixMul
::
Param
::
Format
::
MK4_DOT
;
Benchmarker
<
MatrixMul
>
benchmarker_mk4_dot
(
handle
());
benchmarker_mk4_dot
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
"AARCH32_INT8_MK4_8X
6
X4_DOTPROD"
));
AlgoChecker
<
MatrixMul
>
(
"AARCH32_INT8_MK4_8X
4
X4_DOTPROD"
));
benchmarker_mk4_dot
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Int8
())
.
set_dtype
(
1
,
dtype
::
Int8
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录