Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
eed54081
MegEngine
项目概览
MegEngine 天元
/
MegEngine
10 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
eed54081
编写于
7月 23, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
8月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm): add armv7 mk4 i8i8i16 gemm, optimized for A7
GitOrigin-RevId: d2f8290a8d6577b99adad16e42d57a6ca55a119e
上级
9c475fff
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
890 addition
and
152 deletion
+890
-152
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+8
-3
dnn/src/armv7/matrix_mul/algos.cpp
dnn/src/armv7/matrix_mul/algos.cpp
+73
-6
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+14
-1
dnn/src/armv7/matrix_mul/asm/common.h
dnn/src/armv7/matrix_mul/asm/common.h
+95
-38
dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h
dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h
+14
-14
dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h
dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h
+406
-0
dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp
dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp
+80
-3
dnn/src/armv7/matrix_mul/int8x8x16/strategy.h
dnn/src/armv7/matrix_mul/int8x8x16/strategy.h
+6
-1
dnn/src/armv7/matrix_mul/opr_impl.cpp
dnn/src/armv7/matrix_mul/opr_impl.cpp
+7
-3
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+17
-14
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
+2
-3
dnn/src/fallback/conv_bias/im2col/factory.h
dnn/src/fallback/conv_bias/im2col/factory.h
+24
-10
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+12
-9
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+48
-23
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+29
-2
dnn/test/armv7/matrix_mul.cpp
dnn/test/armv7/matrix_mul.cpp
+50
-18
dnn/test/common/matrix_mul.cpp
dnn/test/common/matrix_mul.cpp
+5
-4
未找到文件。
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/algos.h"
...
...
@@ -733,7 +734,9 @@ void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
bool
MatrixMulImpl
::
AlgoInt8x8x16K8x8x8
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
can_be_treated_as_int8x8x16
(
kern_size_param
);
return
can_be_treated_as_int8x8x16
(
kern_size_param
)
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
;
}
bool
MatrixMulImpl
::
AlgoInt8x8x16K8x8x8
::
preferred
(
...
...
@@ -796,7 +799,9 @@ void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
bool
MatrixMulImpl
::
AlgoInt8x8x16K4x4x16
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
can_be_treated_as_int8x8x16
(
kern_size_param
);
return
can_be_treated_as_int8x8x16
(
kern_size_param
)
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
;
}
bool
MatrixMulImpl
::
AlgoInt8x8x16K4x4x16
::
preferred
(
...
...
dnn/src/armv7/matrix_mul/algos.cpp
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/armv7/matrix_mul/algos.h"
...
...
@@ -526,6 +527,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
"AlgoInt8x8x16K4x8x8"
_hash
,
armv7
::
matmul
::
gemm_s8x8x16_4x8
,
int8_t
,
int16_t
);
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
namespace
{
void
kern_int8x8x16_mk4_k8x8x4
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
MIDOUT_BEGIN
(
megdnn_armv7_matmul_kern
,
midout_iv
(
"kern_int8x8x16_mk4_k8x8x4"
_hash
))
{
auto
M
=
kern_param
.
M
,
N
=
kern_param
.
N
,
K
=
kern_param
.
K
;
auto
Aptr
=
kern_param
.
A
<
dt_int8
>
(),
Bptr
=
kern_param
.
B
<
dt_int8
>
();
auto
Cptr
=
kern_param
.
C
<
dt_int16
>
();
auto
LDA
=
kern_param
.
LDA
,
LDB
=
kern_param
.
LDB
,
LDC
=
kern_param
.
LDC
;
auto
trA
=
kern_param
.
trA
,
trB
=
kern_param
.
trB
;
armv7
::
matmul
::
gemm_s8x8x16_mk4_8x8
strategy
(
M
,
N
,
K
,
kern_param
.
A_type
,
kern_param
.
B_type
,
kern_param
.
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
gemm_s8x8x16_mk4_8x8
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
}
MIDOUT_END
();
}
}
// anonymous namespace
bool
MatrixMulImpl
::
AlgoInt8x8x16MK4_8x8x4
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
bool
type_ok
=
can_be_treated_as_int8x8x16
(
kern_size_param
);
return
type_ok
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
MK4
&&
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
!
kern_size_param
.
trA
&&
!
kern_size_param
.
trB
&&
kern_size_param
.
M
%
4
==
0
&&
kern_size_param
.
K
%
4
==
0
;
}
size_t
MatrixMulImpl
::
AlgoInt8x8x16MK4_8x8x4
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_armv7_matmul_kern
,
midout_iv
(
"AlgoInt8x8x16K8x8x4::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
matmul
::
gemm_s8x8x16_mk4_8x8
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
gemm_s8x8x16_mk4_8x8
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x16MK4_8x8x4
::
get_kern
(
const
KernSizeParam
&
)
const
{
return
kern_int8x8x16_mk4_k8x8x4
;
}
bool
MatrixMulImpl
::
AlgoInt8x8x16MK4_8x8x4
::
preferred
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
kern_size_param
.
K
>=
4
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL
(
AlgoInt8x8x16MK4_8x8x4
,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x16MK4_8x8x4"
_hash
,
armv7
::
matmul
::
gemm_s8x8x16_mk4_8x8
,
int8_t
,
int16_t
,
int16_t
);
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */
namespace
{
...
...
@@ -937,11 +1006,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern(
Bptr
=
kern_param
.
B
<
dt_float16
>
();
auto
Cptr
=
kern_param
.
C
<
dt_float16
>
();
armv7
::
matmul
::
gemm_nopack_f16_4x8
strategy
(
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
gemm_nopack_f16_4x8
,
false
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
armv7
::
matmul
::
gemm_nopack_f16_4x8
strategy
(
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
gemm_nopack_f16_4x8
,
false
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
}
...
...
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
...
...
@@ -171,6 +172,18 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
class
MatrixMulImpl
::
AlgoInt8x8x16MK4_8x8x4
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARMV7_INT8X8X16_MK4_K8X8X4"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
class
MatrixMulImpl
::
AlgoInt16x16x32K12x4x1
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/armv7/matrix_mul/asm/common.h
浏览文件 @
eed54081
...
...
@@ -6,13 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <arm_neon.h>
#include <cmath>
#include <cstdint>
#include <type_traits>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
...
...
@@ -172,7 +174,6 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
[
inptr6
]
"+r"
(
inptr6
),
[
inptr7
]
"+r"
(
inptr7
),
[
outptr
]
"+r"
(
outptr
)
:
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"memory"
);
}
template
<
typename
T
>
...
...
@@ -183,12 +184,12 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1,
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
,
"interleave_4x4_4_b only support uint8_t and int8_t"
);
asm
volatile
(
"vld1.32 {d0, d1}, [%[inptr0]]!
\n
"
// A0A1A2A3
"vld1.32 {d2, d3}, [%[inptr1]]!
\n
"
// B0B1B2B3
"vld1.32 {d4, d5}, [%[inptr2]]!
\n
"
// C0C1C2C3
"vld1.32 {d6, d7}, [%[inptr3]]!
\n
"
// D0D1D2D3
"vtrn.32 q0, q1
\n
"
// A0B0A2B2 A1B1A3B3
"vtrn.32 q2, q3
\n
"
// C0D0C2D2 C1D1C3D3
"vld1.32 {d0, d1}, [%[inptr0]]!
\n
"
// A0A1A2A3
"vld1.32 {d2, d3}, [%[inptr1]]!
\n
"
// B0B1B2B3
"vld1.32 {d4, d5}, [%[inptr2]]!
\n
"
// C0C1C2C3
"vld1.32 {d6, d7}, [%[inptr3]]!
\n
"
// D0D1D2D3
"vtrn.32 q0, q1
\n
"
// A0B0A2B2 A1B1A3B3
"vtrn.32 q2, q3
\n
"
// C0D0C2D2 C1D1C3D3
"vswp d1, d4
\n
"
// q0=A0,B0,C0,D0 q2=A2,B2,C2,D2
"vswp d3, d6
\n
"
// q1=A1,B1,C1,D1 q3=A3,B3,C3,D3
"vst1.32 {d0-d1},[%[outptr]]!
\n
"
...
...
@@ -323,10 +324,10 @@ static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1,
"vtrn.32 q1, q3
\n
"
// q1=r02,r12,r03,r13 q3=r06,r16,r07,r17
"vtrn.32 q5, q7
\n
"
// q5=r22,r32,r23,r33 q7=r26,r36,r27,r37
"vtrn.32 q9, q11
\n
"
// q9=r42,r52,r43,r53 q11=r46,r56,r47,r57
"vst1.32 {d0-d1}, [%[outptr]]!
\n
"
"vst1.32 {d16}, [%[outptr]]!
\n
"
"vst1.32 {d0-d1}, [%[outptr]]!
\n
"
"vst1.32 {d16}, [%[outptr]]!
\n
"
"vswp d3, d10
\n
"
// q1=r02,r12,r22,r32 q5=r03,r13,r23,r33
"vst1.32 {d8-d9}, [%[outptr]]!
\n
"
"vst1.32 {d8-d9}, [%[outptr]]!
\n
"
"vst1.32 {d17}, [%[outptr]]!
\n
"
"vst1.32 {d2-d3}, [%[outptr]]!
\n
"
"vst1.32 {d18}, [%[outptr]]!
\n
"
...
...
@@ -810,15 +811,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1,
"interleave_12x4_1_h only support uint16_t and int16_t"
);
auto
ldin_asm
=
ldin
<<
1
;
asm
volatile
(
"vld1.16 {d0}, [%[inptr0]]!
\n
"
// A0A1A2A3
"vld1.16 {d1}, [%[inptr1]]!
\n
"
// B0B1B2B3
"vld1.16 {d2}, [%[inptr2]]!
\n
"
// C0C1C2C3
"vld1.16 {d3}, [%[inptr3]]!
\n
"
// D0D1D2D3
"vld1.16 {d4}, [%[inptr4]]!
\n
"
// E0E1E2E3
"vld1.16 {d5}, [%[inptr5]]!
\n
"
// F0F1F2F3
"vld1.16 {d6}, [%[inptr6]]!
\n
"
// G0G1G2G3
"vld1.16 {d7}, [%[inptr7]]!
\n
"
// H0H1H2H3
"vld1.16 {d8}, [%[inptr8]]!
\n
"
// I0I1I2I3
"vld1.16 {d0}, [%[inptr0]]!
\n
"
// A0A1A2A3
"vld1.16 {d1}, [%[inptr1]]!
\n
"
// B0B1B2B3
"vld1.16 {d2}, [%[inptr2]]!
\n
"
// C0C1C2C3
"vld1.16 {d3}, [%[inptr3]]!
\n
"
// D0D1D2D3
"vld1.16 {d4}, [%[inptr4]]!
\n
"
// E0E1E2E3
"vld1.16 {d5}, [%[inptr5]]!
\n
"
// F0F1F2F3
"vld1.16 {d6}, [%[inptr6]]!
\n
"
// G0G1G2G3
"vld1.16 {d7}, [%[inptr7]]!
\n
"
// H0H1H2H3
"vld1.16 {d8}, [%[inptr8]]!
\n
"
// I0I1I2I3
"vld1.16 {d9}, [%[inptr9]]
\n
"
// J0J1J2J3
"add %[inptr9], %[inptr9], %[ldin_asm]
\n
"
"vld1.16 {d10}, [%[inptr9]]
\n
"
// K0K1K2K3
...
...
@@ -854,17 +855,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1,
[
inptr3
]
"+r"
(
inptr3
),
[
inptr4
]
"+r"
(
inptr4
),
[
inptr5
]
"+r"
(
inptr5
),
[
inptr6
]
"+r"
(
inptr6
),
[
inptr7
]
"+r"
(
inptr7
),
[
inptr8
]
"+r"
(
inptr8
),
[
inptr9
]
"+r"
(
inptr9
),
[
outptr
]
"+r"
(
outptr
)
:
[
ldin_asm
]
"r"
(
ldin_asm
)
:
[
ldin_asm
]
"r"
(
ldin_asm
)
:
"d0"
,
"d1"
,
"d2"
,
"d3"
,
"d4"
,
"d5"
,
"d6"
,
"d7"
,
"d8"
,
"d9"
,
"d10"
,
"d11"
,
"memory"
);
inptr9
-=
ldin_asm
;
inptr9
+=
4
;
inptr9
-=
ldin_asm
;
inptr9
+=
4
;
inptr10
+=
4
;
inptr11
+=
4
;
}
template
<
typename
T
>
static
inline
void
transpose_2x16_1_b_helper
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1038,7 +1037,7 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1,
"vst1.32 {d7}, [%[outptr]], %[stride]
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
inptr1
]
"+r"
(
inptr1
),
[
inptr2
]
"+r"
(
inptr2
),
[
inptr3
]
"+r"
(
inptr3
),
[
outptr
]
"+r"
(
outptr
),
[
stride
]
"+r"
(
stride
)
[
outptr
]
"+r"
(
outptr
),
[
stride
]
"+r"
(
stride
)
:
:
"d0"
,
"d1"
,
"d2"
,
"d3"
,
"d4"
,
"d5"
,
"d6"
,
"d7"
,
"memory"
);
}
...
...
@@ -1069,7 +1068,6 @@ static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1,
:
"d0"
,
"d1"
,
"d2"
,
"d3"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
transpose_6x4_1_b
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1082,9 +1080,9 @@ static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1,
"vld1.8 {d1}, [%[inptr1]]
\n
"
// B0B1B2B3B4B5 B6B7
"vld1.8 {d2}, [%[inptr2]]
\n
"
// C0C1C2C3C4C5 C6C7
"vld1.8 {d3}, [%[inptr3]]
\n
"
// D0D1D2D3D4D5 D6D7
"vtrn.8 d0, d1
\n
"
// A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3
\n
"
// C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"vtrn.8 d0, d1
\n
"
// A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3
\n
"
// C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"add %[inptr0],%[inptr0],#6
\n
"
"add %[inptr1],%[inptr1],#6
\n
"
"add %[inptr2],%[inptr2],#6
\n
"
...
...
@@ -1121,9 +1119,9 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1,
"vld1.8 {d1}, [%[inptr1]]
\n
"
// B0B1B2B3B4B5 B6B7
"vld1.8 {d2}, [%[inptr2]]
\n
"
// C0C1C2C3C4C5 C6C7
"vld1.8 {d3}, [%[inptr3]]
\n
"
// D0D1D2D3D4D5 D6D7
"vtrn.8 d0, d1
\n
"
// A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3
\n
"
// C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"vtrn.8 d0, d1
\n
"
// A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7
"vtrn.8 d2, d3
\n
"
// C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7
"add %[inptr0],%[inptr0],#4
\n
"
"add %[inptr1],%[inptr1],#4
\n
"
"add %[inptr2],%[inptr2],#4
\n
"
...
...
@@ -1176,7 +1174,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
"vst1.32 {d14-d15}, [%[outptr]]!
\n
"
"vst1.32 {d22-d23}, [%[outptr]]!
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
outptr
]
"+r"
(
outptr
)
:
[
inptr0
]
"+r"
(
inptr0
),
[
outptr
]
"+r"
(
outptr
)
:
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"memory"
);
...
...
@@ -1195,12 +1193,11 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
"vst1.32 {d4-d5}, [%[outptr]]!
\n
"
"vst1.32 {d2-d3}, [%[outptr]]!
\n
"
"vst1.32 {d6-d7}, [%[outptr]]!
\n
"
:
[
inptr0
]
"+r"
(
inptr0
),
[
outptr
]
"+r"
(
outptr
)
:
[
inptr0
]
"+r"
(
inptr0
),
[
outptr
]
"+r"
(
outptr
)
:
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
transpose_4
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*
outptr
,
...
...
@@ -1251,7 +1248,6 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1,
}
}
template
<
typename
T
>
static
inline
void
transpose_4x1
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1375,7 +1371,68 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr,
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"memory"
);
}
}
// armv7
static
inline
void
interleave_4x4_8x4_s8_s16
(
const
int8_t
*
inptr0
,
const
int8_t
*
inptr1
,
int16_t
*
outptr
)
{
int8x16_t
row0
=
vld1q_s8
(
inptr0
);
int16x8_t
row0_01
=
vmovl_low_s8
(
row0
);
int16x8_t
row0_23
=
vmovl_high_s8
(
row0
);
int16x4_t
row0_0
=
vget_low_s16
(
row0_01
);
int16x4_t
row0_1
=
vget_high_s16
(
row0_01
);
int16x4_t
row0_2
=
vget_low_s16
(
row0_23
);
int16x4_t
row0_3
=
vget_high_s16
(
row0_23
);
int8x16_t
row1
=
vld1q_s8
(
inptr1
);
int16x8_t
row1_01
=
vmovl_low_s8
(
row1
);
int16x8_t
row1_23
=
vmovl_high_s8
(
row1
);
int16x4_t
row1_0
=
vget_low_s16
(
row1_01
);
int16x4_t
row1_1
=
vget_high_s16
(
row1_01
);
int16x4_t
row1_2
=
vget_low_s16
(
row1_23
);
int16x4_t
row1_3
=
vget_high_s16
(
row1_23
);
vst1_s16
(
outptr
,
row0_0
);
vst1_s16
(
outptr
+
1
*
4
,
row1_0
);
vst1_s16
(
outptr
+
2
*
4
,
row0_1
);
vst1_s16
(
outptr
+
3
*
4
,
row1_1
);
vst1_s16
(
outptr
+
4
*
4
,
row0_2
);
vst1_s16
(
outptr
+
5
*
4
,
row1_2
);
vst1_s16
(
outptr
+
6
*
4
,
row0_3
);
vst1_s16
(
outptr
+
7
*
4
,
row1_3
);
};
static
inline
void
transpos_8x4_int8
(
const
int8_t
*
inptr0
,
int8_t
*
outptr
)
{
int8x8x4_t
input
=
vld4_s8
(
inptr0
);
vst1_s8
(
outptr
,
input
.
val
[
0
]);
vst1_s8
(
outptr
+
1
*
8
,
input
.
val
[
1
]);
vst1_s8
(
outptr
+
2
*
8
,
input
.
val
[
2
]);
vst1_s8
(
outptr
+
3
*
8
,
input
.
val
[
3
]);
}
static
inline
void
memcpy_s8_s16
(
const
int8_t
*
inptr
,
int16_t
*
outptr
,
int
count
)
{
for
(;
count
>=
32
;
count
-=
32
)
{
int8x8_t
in0
=
vld1_s8
(
inptr
);
int8x8_t
in1
=
vld1_s8
(
inptr
+
1
*
8
);
int8x8_t
in2
=
vld1_s8
(
inptr
+
2
*
8
);
int8x8_t
in3
=
vld1_s8
(
inptr
+
3
*
8
);
vst1q_s16
(
outptr
,
vmovl_s8
(
in0
));
vst1q_s16
(
outptr
+
1
*
8
,
vmovl_s8
(
in1
));
vst1q_s16
(
outptr
+
2
*
8
,
vmovl_s8
(
in2
));
vst1q_s16
(
outptr
+
3
*
8
,
vmovl_s8
(
in3
));
inptr
+=
32
;
outptr
+=
32
;
}
for
(;
count
>=
8
;
count
-=
8
)
{
int8x8_t
in0
=
vld1_s8
(
inptr
);
vst1q_s16
(
outptr
,
vmovl_s8
(
in0
));
inptr
+=
8
;
outptr
+=
8
;
}
for
(;
count
>
0
;
--
count
)
{
*
outptr
++
=
(
int16_t
)(
*
inptr
++
);
}
}
}
// namespace armv7
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h
浏览文件 @
eed54081
...
...
@@ -102,60 +102,60 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
"vld1.8 {d2}, [%[a_ptr]]!
\n
"
"vld1.8 {d4}, [%[a_ptr]]!
\n
"
"vld1.8 {d6}, [%[a_ptr]]!
\n
"
"vld1.8 {d18}, [%[b_ptr]]!
\n
"
"vmovl.s8 q8, d16
\n
"
"vmovl.s8 q0, d0
\n
"
"vmovl.s8 q1, d2
\n
"
"vmovl.s8 q2, d4
\n
"
"vmovl.s8 q3, d6
\n
"
"vld1.8 {d18}, [%[b_ptr]]!
\n
"
"vmovl.s8 q9, d18
\n
"
"vld1.8 {d20}, [%[b_ptr]]!
\n
"
"vmla.s16 q4, q8, d0[0]
\n
"
"vmla.s16 q5, q8, d2[0]
\n
"
"vmla.s16 q6, q8, d4[0]
\n
"
"vmla.s16 q7, q8, d6[0]
\n
"
"vmovl.s8 q9, d18
\n
"
"vld1.8 {d20}, [%[b_ptr]]!
\n
"
"vmovl.s8 q10, d20
\n
"
"vld1.8 {d22}, [%[b_ptr]]!
\n
"
"vmla.s16 q4, q9, d0[1]
\n
"
"vmla.s16 q5, q9, d2[1]
\n
"
"vmla.s16 q6, q9, d4[1]
\n
"
"vmla.s16 q7, q9, d6[1]
\n
"
"vmovl.s8 q10, d20
\n
"
"vld1.8 {d22}, [%[b_ptr]]!
\n
"
"vmovl.s8 q11, d22
\n
"
"vld1.8 {d24}, [%[b_ptr]]!
\n
"
"vmla.s16 q4, q10, d0[2]
\n
"
"vmla.s16 q5, q10, d2[2]
\n
"
"vmla.s16 q6, q10, d4[2]
\n
"
"vmla.s16 q7, q10, d6[2]
\n
"
"vmovl.s8 q11, d22
\n
"
"vld1.8 {d24}, [%[b_ptr]]!
\n
"
"vmovl.s8 q12, d24
\n
"
"vld1.8 {d26}, [%[b_ptr]]!
\n
"
"vmla.s16 q4, q11, d0[3]
\n
"
"vmla.s16 q5, q11, d2[3]
\n
"
"vmla.s16 q6, q11, d4[3]
\n
"
"vmla.s16 q7, q11, d6[3]
\n
"
"vmovl.s8 q12, d24
\n
"
"vld1.8 {d26}, [%[b_ptr]]!
\n
"
"vmovl.s8 q13, d26
\n
"
"vld1.8 {d28}, [%[b_ptr]]!
\n
"
"vmla.s16 q4, q12, d1[0]
\n
"
"vmla.s16 q5, q12, d3[0]
\n
"
"vmla.s16 q6, q12, d5[0]
\n
"
"vmla.s16 q7, q12, d7[0]
\n
"
"vmovl.s8 q13, d26
\n
"
"vld1.8 {d28}, [%[b_ptr]]!
\n
"
"vmovl.s8 q14, d28
\n
"
"vld1.8 {d30}, [%[b_ptr]]!
\n
"
"vmla.s16 q4, q13, d1[1]
\n
"
"vmla.s16 q5, q13, d3[1]
\n
"
"vmla.s16 q6, q13, d5[1]
\n
"
"vmla.s16 q7, q13, d7[1]
\n
"
"vmovl.s8 q14, d28
\n
"
"v
ld1.8 {d30}, [%[b_ptr]]!
\n
"
"v
movl.s8 q15, d30
\n
"
"vmla.s16 q4, q14, d1[2]
\n
"
"vmla.s16 q5, q14, d3[2]
\n
"
"vmla.s16 q6, q14, d5[2]
\n
"
"vmla.s16 q7, q14, d7[2]
\n
"
"vmovl.s8 q15, d30
\n
"
"vmla.s16 q4, q15, d1[3]
\n
"
"vmla.s16 q5, q15, d3[3]
\n
"
...
...
dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h
0 → 100644
浏览文件 @
eed54081
/**
* \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
namespace
megdnn
{
namespace
armv7
{
namespace
matmul_mk4_8x8x4
{
//! optimize for A7
/**
* Overview of register layout:
*
* A 8x8x8 cell of Lhs is stored in 16bit in q0, q1
* A 8x8x8 cell of Rhs is stored in 8bit in q2, q3
* A 8x8 block of accumulators is stored in 16bit in q8-q15
*
* +--------+
* | q4[0-8]|
* Rhs +--------+
* Lhs | |
*
* +--------+ - - - - +---------
* |q0[0]| | q8 [0-8]|
* |q0[1]| | q9 [0-8]|
* |q0[2]| | q10[0-8]|
* |q0[3]| | q11[0-8]|
* |q0[4]| | q12[0-8]|
* |q0[5]| | q13[0-8]|
* |q0[6]| | q14[0-8]|
* |q0[7]| | q15[0-8]|
* +--------+ - - - - +---------
*
* Accumulator
*/
static
void
kern_8x8
(
const
int16_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
remain_n
)
{
K
/=
4
;
const
int16_t
*
a_ptr
=
packA
;
const
int8_t
*
b_ptr
=
packB
;
LDC
=
LDC
*
sizeof
(
int16_t
);
int
x0
=
0
;
// clang-format off
#define STORE_LINE(reg_index1, reg_index2) \
"cmp %[x0], #0 \n" \
"beq 101f\n" \
"vst1.16 {d" reg_index1 "}, [r0]!\n" \
"vst1.16 {d" reg_index2 "}, [r1]!\n" \
"subs %[x0], %[x0], #1\n"
#define STORE_C \
"mov %[x0], %[remain_n]\n" \
STORE_LINE("16", "17") \
STORE_LINE("18", "19") \
STORE_LINE("20", "21") \
STORE_LINE("22", "23") \
STORE_LINE("24", "25") \
STORE_LINE("26", "27") \
STORE_LINE("28", "29") \
STORE_LINE("30", "31") \
"101:\n"
// clang-format on
register
int16_t
*
outptr
asm
(
"r0"
)
=
output
;
asm
volatile
(
// load accumulator C
"add r1, r0, %[LDC]
\n
"
"cmp %[is_first_k], #1
\n
"
"beq 1f
\n
"
"b 2f
\n
"
"1:
\n
"
"veor.s32 q8 , q8 , q8
\n
"
"veor.s32 q9 , q9 , q9
\n
"
"veor.s32 q10, q10, q10
\n
"
"veor.s32 q11, q11, q11
\n
"
"veor.s32 q12, q12, q12
\n
"
"veor.s32 q13, q13, q13
\n
"
"veor.s32 q14, q14, q14
\n
"
"veor.s32 q15, q15, q15
\n
"
"2:
\n
"
"vld1.8 {d4}, [%[b_ptr]]!
\n
"
"vld1.16 {d0, d1}, [%[a_ptr]]!
\n
"
"vmovl.s8 q2, d4
\n
"
"vld1.16 {d2, d3}, [%[a_ptr]]!
\n
"
"vld1.8 {d6}, [%[b_ptr]]!
\n
"
//! k0
"vmla.s16 q8, q0, d4[0]
\n
"
"vmla.s16 q9, q0, d4[1]
\n
"
"vmla.s16 q10, q0, d4[2]
\n
"
"vmla.s16 q11, q0, d4[3]
\n
"
"vmovl.s8 q3, d6
\n
"
"vmla.s16 q12, q0, d5[0]
\n
"
"vmla.s16 q13, q0, d5[1]
\n
"
"vmla.s16 q14, q0, d5[2]
\n
"
"vmla.s16 q15, q0, d5[3]
\n
"
//! k1
"vld1.16 {d0, d1}, [%[a_ptr]]!
\n
"
"vld1.8 {d4}, [%[b_ptr]]!
\n
"
"vmla.s16 q8, q1, d6[0]
\n
"
"vmla.s16 q9, q1, d6[1]
\n
"
"vmla.s16 q10, q1, d6[2]
\n
"
"vmla.s16 q11, q1, d6[3]
\n
"
"vmovl.s8 q2, d4
\n
"
"vmla.s16 q12, q1, d7[0]
\n
"
"vmla.s16 q13, q1, d7[1]
\n
"
"vmla.s16 q14, q1, d7[2]
\n
"
"vmla.s16 q15, q1, d7[3]
\n
"
//! k2
"vld1.16 {d2, d3}, [%[a_ptr]]!
\n
"
"vld1.8 {d6}, [%[b_ptr]]!
\n
"
"vmla.s16 q8, q0, d4[0]
\n
"
"vmla.s16 q9, q0, d4[1]
\n
"
"vmla.s16 q10, q0, d4[2]
\n
"
"vmla.s16 q11, q0, d4[3]
\n
"
"vmovl.s8 q3, d6
\n
"
"vmla.s16 q12, q0, d5[0]
\n
"
"vmla.s16 q13, q0, d5[1]
\n
"
"vmla.s16 q14, q0, d5[2]
\n
"
"vmla.s16 q15, q0, d5[3]
\n
"
//! k3
"vmla.s16 q8, q1, d6[0]
\n
"
"vmla.s16 q9, q1, d6[1]
\n
"
"vmla.s16 q10, q1, d6[2]
\n
"
"vmla.s16 q11, q1, d6[3]
\n
"
"vmla.s16 q12, q1, d7[0]
\n
"
"vmla.s16 q13, q1, d7[1]
\n
"
"vmla.s16 q14, q1, d7[2]
\n
"
"vmla.s16 q15, q1, d7[3]
\n
"
"subs %[K], %[K], #1
\n
"
"bne 2b
\n
"
"3:
\n
"
"cmp %[remain_n], #8
\n
"
"bne 4f
\n
"
"vstr d16, [r0]
\n
"
"vstr d18, [r0, #8]
\n
"
"vstr d20, [r0, #16]
\n
"
"vstr d22, [r0, #24]
\n
"
"vstr d24, [r0, #32]
\n
"
"vstr d26, [r0, #40]
\n
"
"vstr d28, [r0, #48]
\n
"
"vstr d30, [r0, #56]
\n
"
"vstr d17, [r1]
\n
"
"vstr d19, [r1, #8]
\n
"
"vstr d21, [r1, #16]
\n
"
"vstr d23, [r1, #24]
\n
"
"vstr d25, [r1, #32]
\n
"
"vstr d27, [r1, #40]
\n
"
"vstr d29, [r1, #48]
\n
"
"vstr d31, [r1, #56]
\n
"
"b 101f
\n
"
"4:
\n
"
STORE_C
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
x0
]
"+r"
(
x0
),
[
LDC
]
"+r"
(
LDC
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
outptr
]
"+r"
(
outptr
),
[
remain_n
]
"+r"
(
remain_n
)
:
:
"d0"
,
"d1"
,
"d2"
,
"d3"
,
"d4"
,
"d5"
,
"d6"
,
"d7"
,
"d8"
,
"d9"
,
"d10"
,
"d11"
,
"d12"
,
"d13"
,
"d14"
,
"d15"
,
"d16"
,
"d17"
,
"d18"
,
"d19"
,
"d20"
,
"d21"
,
"d22"
,
"d23"
,
"d24"
,
"d25"
,
"d26"
,
"d27"
,
"d28"
,
"d29"
,
"d30"
,
"d31"
,
"r1"
,
"r2"
,
"r3"
,
"cc"
,
"memory"
);
#undef STORE_C
#undef STORE_LINE
}
/**
* Overview of register layout:
*
* A 8x8x8 cell of Lhs is stored in 16bit in d0, d2
* A 8x8x8 cell of Rhs is stored in 8bit in q2, q3
* A 8x8 block of accumulators is stored in 16bit in q8-11
*
* +--------+
* | q4[0-8]|
* Rhs +--------+
* Lhs | |
*
* +--------+ - - - - +---------
* |d0[0]| | q8 [0-8]|
* |d0[1]| | q9 [0-8]|
* |d0[2]| | q10[0-8]|
* |d0[3]| | q11[0-8]|
* +--------+ - - - - +---------
*
* Accumulator
*/
static
void
kern_4x8
(
const
int16_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
remain_n
)
{
K
/=
4
;
const
int16_t
*
a_ptr
=
packA
;
const
int8_t
*
b_ptr
=
packB
;
LDC
=
LDC
*
sizeof
(
int16_t
);
int
x0
=
0
;
// clang-format off
#define STORE_LINE(reg_index1) \
"cmp %[x0], #0 \n" \
"beq 101f\n" \
"vst1.16 {d" reg_index1 "}, [r0]!\n" \
"subs %[x0], %[x0], #1\n"
#define STORE_C \
"mov %[x0], %[remain_n]\n" \
STORE_LINE("16") \
STORE_LINE("18") \
STORE_LINE("20") \
STORE_LINE("22") \
STORE_LINE("24") \
STORE_LINE("26") \
STORE_LINE("28") \
STORE_LINE("30") \
"101:\n"
// clang-format on
register
int16_t
*
outptr
asm
(
"r0"
)
=
output
;
asm
volatile
(
//! load accumulator C
"add r1, r0, %[LDC]
\n
"
"cmp %[is_first_k], #1
\n
"
"beq 1f
\n
"
"b 2f
\n
"
"1:
\n
"
"veor.s32 q8 , q8 , q8
\n
"
"veor.s32 q9 , q9 , q9
\n
"
"veor.s32 q10, q10, q10
\n
"
"veor.s32 q11, q11, q11
\n
"
"veor.s32 q12, q12, q12
\n
"
"veor.s32 q13, q13, q13
\n
"
"veor.s32 q14, q14, q14
\n
"
"veor.s32 q15, q15, q15
\n
"
"2:
\n
"
"vld1.8 {d4}, [%[b_ptr]]!
\n
"
"vld1.16 {d0}, [%[a_ptr]]!
\n
"
"vmovl.s8 q2, d4
\n
"
"vld1.16 {d2}, [%[a_ptr]]!
\n
"
"vld1.8 {d6}, [%[b_ptr]]!
\n
"
//! k0
"vmla.s16 d16, d0, d4[0]
\n
"
"vmla.s16 d18, d0, d4[1]
\n
"
"vmla.s16 d20, d0, d4[2]
\n
"
"vmla.s16 d22, d0, d4[3]
\n
"
"vmovl.s8 q3, d6
\n
"
"vmla.s16 d24, d0, d5[0]
\n
"
"vmla.s16 d26, d0, d5[1]
\n
"
"vmla.s16 d28, d0, d5[2]
\n
"
"vmla.s16 d30, d0, d5[3]
\n
"
//! k1
"vld1.16 {d0}, [%[a_ptr]]!
\n
"
"vld1.8 {d4}, [%[b_ptr]]!
\n
"
"vmla.s16 d16, d2, d6[0]
\n
"
"vmla.s16 d18, d2, d6[1]
\n
"
"vmla.s16 d20, d2, d6[2]
\n
"
"vmla.s16 d22, d2, d6[3]
\n
"
"vmovl.s8 q2, d4
\n
"
"vmla.s16 d24, d2, d7[0]
\n
"
"vmla.s16 d26, d2, d7[1]
\n
"
"vmla.s16 d28, d2, d7[2]
\n
"
"vmla.s16 d30, d2, d7[3]
\n
"
//! k2
"vld1.16 {d2}, [%[a_ptr]]!
\n
"
"vld1.8 {d6}, [%[b_ptr]]!
\n
"
"vmla.s16 d16, d0, d4[0]
\n
"
"vmla.s16 d18, d0, d4[1]
\n
"
"vmla.s16 d20, d0, d4[2]
\n
"
"vmla.s16 d22, d0, d4[3]
\n
"
"vmovl.s8 q3, d6
\n
"
"vmla.s16 d24, d0, d5[0]
\n
"
"vmla.s16 d26, d0, d5[1]
\n
"
"vmla.s16 d28, d0, d5[2]
\n
"
"vmla.s16 d30, d0, d5[3]
\n
"
//! k3
"vmla.s16 d16, d2, d6[0]
\n
"
"vmla.s16 d18, d2, d6[1]
\n
"
"vmla.s16 d20, d2, d6[2]
\n
"
"vmla.s16 d22, d2, d6[3]
\n
"
"vmla.s16 d24, d2, d7[0]
\n
"
"vmla.s16 d26, d2, d7[1]
\n
"
"vmla.s16 d28, d2, d7[2]
\n
"
"vmla.s16 d30, d2, d7[3]
\n
"
"subs %[K], %[K], #1
\n
"
"bne 2b
\n
"
"3:
\n
"
"cmp %[remain_n], #8
\n
"
"bne 4f
\n
"
"vstr d16, [r0]
\n
"
"vstr d18, [r0, #8]
\n
"
"vstr d20, [r0, #16]
\n
"
"vstr d22, [r0, #24]
\n
"
"vstr d24, [r0, #32]
\n
"
"vstr d26, [r0, #40]
\n
"
"vstr d28, [r0, #48]
\n
"
"vstr d30, [r0, #56]
\n
"
"b 101f
\n
"
"4:
\n
"
STORE_C
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
x0
]
"+r"
(
x0
),
[
LDC
]
"+r"
(
LDC
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
outptr
]
"+r"
(
outptr
),
[
remain_n
]
"+r"
(
remain_n
)
:
:
"d0"
,
"d1"
,
"d2"
,
"d3"
,
"d4"
,
"d5"
,
"d6"
,
"d7"
,
"d8"
,
"d9"
,
"d10"
,
"d11"
,
"d12"
,
"d13"
,
"d14"
,
"d15"
,
"d16"
,
"d17"
,
"d18"
,
"d19"
,
"d20"
,
"d21"
,
"d22"
,
"d23"
,
"d24"
,
"d25"
,
"d26"
,
"d27"
,
"d28"
,
"d29"
,
"d30"
,
"d31"
,
"r1"
,
"r2"
,
"r3"
,
"cc"
,
"memory"
);
#undef STORE_C
#undef STORE_LINE
}
static
void
gemm_s8x8x16_mk4_8x8_pack_A_n
(
dt_int16
*
outptr
,
const
dt_int8
*
inptr
,
int
ldin
,
int
m0
,
int
mmax
,
int
k0
,
int
kmax
)
{
megdnn_assert
(
m0
%
4
==
0
&&
mmax
%
4
==
0
,
"M must be time of 4"
);
megdnn_assert
(
k0
%
4
==
0
&&
kmax
%
4
==
0
,
"K must be time of 4"
);
constexpr
int
pack_m
=
8
;
constexpr
int
pack_k
=
4
;
constexpr
int
pack_size
=
4
;
const
int
m_size
=
mmax
-
m0
;
const
int
m_end
=
m_size
/
pack_m
*
pack_m
+
m0
;
const
int
remain_m
=
mmax
-
m_end
;
for
(
int
m_idx
=
m0
;
m_idx
<
m_end
;
m_idx
+=
pack_m
)
{
const
int8_t
*
inptr0
=
inptr
+
m_idx
/
pack_size
*
ldin
+
k0
;
const
int8_t
*
inptr1
=
inptr0
+
ldin
;
prefetch_2x
(
inptr0
);
prefetch_2x
(
inptr1
);
for
(
int
k_idx
=
k0
;
k_idx
<
kmax
;
k_idx
+=
pack_size
)
{
interleave_4x4_8x4_s8_s16
(
inptr0
,
inptr1
,
outptr
);
inptr0
+=
pack_size
*
pack_size
;
inptr1
+=
pack_size
*
pack_size
;
outptr
+=
pack_m
*
pack_k
;
}
}
if
(
remain_m
>
0
)
{
const
int8_t
*
inptr0
=
inptr
+
m_end
/
pack_size
*
ldin
+
k0
;
const
int
k_size
=
kmax
-
k0
;
memcpy_s8_s16
(
inptr0
,
outptr
,
k_size
*
pack_size
);
}
}
static
void
gemm_s8x8x16_mk4_8x8_pack_B_n
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
n0
,
int
nmax
,
int
k0
,
int
kmax
)
{
megdnn_assert
(
k0
%
4
==
0
&&
kmax
%
4
==
0
,
"K must be time of 4"
);
int8_t
tmpbuff
[
32
]
=
{
0
};
constexpr
int
pack_n
=
8
;
constexpr
int
pack_size
=
4
;
const
int
ksize
=
kmax
-
k0
;
const
int
nsize
=
nmax
-
n0
;
const
int
n_end
=
nsize
/
pack_n
*
pack_n
+
n0
;
const
int
remain_n
=
nsize
%
pack_n
;
int
output_stride
=
ksize
*
pack_n
;
int8_t
*
outptr_base
=
out
;
for
(
int
k_idx
=
k0
;
k_idx
<
kmax
;
k_idx
+=
pack_size
)
{
const
int8_t
*
inptr
=
in
+
k_idx
/
pack_size
*
ldin
+
n0
*
pack_size
;
prefetch_3x
(
inptr
);
auto
outptr
=
outptr_base
;
for
(
int
n_idx
=
n0
;
n_idx
<
n_end
;
n_idx
+=
pack_n
)
{
transpos_8x4_int8
(
inptr
,
outptr
);
inptr
+=
pack_n
*
pack_size
;
outptr
+=
output_stride
;
}
if
(
remain_n
>
0
)
{
memcpy
(
tmpbuff
,
inptr
,
sizeof
(
int8_t
)
*
remain_n
*
pack_size
);
transpos_8x4_int8
(
tmpbuff
,
outptr
);
outptr
+=
output_stride
;
}
outptr_base
+=
pack_n
*
pack_size
;
}
}
}
// namespace matmul_mk4_8x8x4
}
// namespace armv7
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp
浏览文件 @
eed54081
...
...
@@ -6,14 +6,16 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h"
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
...
...
@@ -108,7 +110,7 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_4x
4
==================================
// ===========================gemm_s8x8x16_4x
8
==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_s8x8x16_4x8
);
void
gemm_s8x8x16_4x8
::
pack_A
(
dt_int8
*
out
,
const
dt_int8
*
in
,
int
ldin
,
int
y0
,
...
...
@@ -179,4 +181,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_mk4_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_s8x8x16_mk4_8x8
);
//! Pack matrix A for the int8x8x16 MK4 8x8x4 strategy.
//!
//! Thin delegation to the MK4-specific packer; note the packed A type is
//! dt_int16 (the packer widens int8 inputs to int16), matching the
//! MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE registration.
//! The trailing bool (transpose flag) is ignored: this strategy is
//! registered as non-transposing for A.
void gemm_s8x8x16_mk4_8x8::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                  int y0, int ymax, int k0, int kmax,
                                  bool) const {
    matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax,
                                                    k0, kmax);
}
//! Pack matrix B for the int8x8x16 MK4 8x8x4 strategy.
//!
//! Thin delegation to the MK4-specific packer (panels of 8 columns, K in
//! groups of 4, tail columns zero-padded). The trailing bool (transpose
//! flag) is ignored: this strategy is registered as non-transposing for B.
void gemm_s8x8x16_mk4_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                  int x0, int xmax, int k0, int kmax,
                                  bool) const {
    matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax,
                                                    k0, kmax);
}
//! Main GEMM driver for the armv7 int8x8x16 MK4 strategy: tiles the packed
//! operands into 8x8 (or 4x8 for the M tail) micro-kernel calls.
//!
//! \param packA  A packed by pack_A (widened to int16)
//! \param packB  B packed by pack_B
//! \param M,N,K  GEMM dimensions; M and K must be multiples of 4 (MK4)
//! \param C      output, MK4 layout: each LDC row holds one 4-channel group
//! \param LDC    leading dimension of C, in elements
//! \param is_first_k  must be true — accumulation across K slices is not
//!                    implemented by this kernel
void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB,
                                size_t M, size_t N, size_t K, dt_int16* C,
                                size_t LDC, bool is_first_k,
                                const dt_int16*, dt_int16*) const {
    //! only int8 x int8 -> int16 is supported by this strategy
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  C_dtype.enumv() == DTypeEnum::Int16 &&
                  A_dtype.enumv() == DTypeEnum::Int8);
    megdnn_assert(is_first_k == true, "only impl is_first_k");
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");

    constexpr size_t pack_size = 4;  //! MK4 channel-group size
    constexpr size_t pack_m = 8;     //! rows handled per kern_8x8 call
    constexpr size_t pack_n = 8;     //! columns handled per micro-kernel call
    const size_t remain_n = N % pack_n;  //! ragged column tail (< 8)
    const size_t remain_m = M % pack_m;  //! ragged row tail; 0 or 4 since M%4==0

    //! full 8-row blocks: one kern_8x8 per 8x8 output tile
    size_t m_idx = 0;
    for (; m_idx + pack_m <= M; m_idx += pack_m) {
        //! MK4: each LDC row of C covers pack_size (4) logical rows
        int16_t* output = C + (m_idx / pack_size * LDC);

        size_t n_idx = 0;
        const int8_t* cur_packB = packB;  //! B panels restart for every m block
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC,
                                       is_first_k, pack_n);
            output += pack_n * pack_size;
            cur_packB += pack_n * K;  //! each B panel is pack_n * K bytes
        }
        if (remain_n > 0) {
            //! tail panel is zero-padded to 8 columns by pack_B, so the
            //! same micro-kernel runs; it writes only remain_n columns
            matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC,
                                       is_first_k, remain_n);
            output += remain_n * pack_size;
            cur_packB += pack_n * K;
        }
        packA += pack_m * K;  //! advance past this 8-row packed A block
    }
    //! row tail (4 remaining rows): same column sweep with kern_4x8
    if (remain_m > 0) {
        int16_t* output = C + (m_idx / pack_size * LDC);
        size_t n_idx = 0;
        const int8_t* cur_packB = packB;
        for (; n_idx + pack_n <= N; n_idx += pack_n) {
            matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC,
                                       is_first_k, pack_n);
            output += pack_n * pack_size;
            cur_packB += pack_n * K;
        }
        if (remain_n > 0) {
            matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC,
                                       is_first_k, remain_n);
            output += remain_n * pack_size;
            cur_packB += pack_n * K;
        }
    }
}
// vim: syntax=cpp.doxygen
dnn/src/armv7/matrix_mul/int8x8x16/strategy.h
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
...
...
@@ -21,6 +22,10 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true,
MEGDNN_REG_GEMM_STRATEGY
(
int8_t
,
int16_t
,
int16_t
,
4
,
8
,
8
,
false
,
true
,
gemm_s8x8x16_4x8
);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE
(
int8_t
,
int16_t
,
int16_t
,
int16_t
,
8
,
8
,
4
,
false
,
false
,
gemm_s8x8x16_mk4_8x8
);
}
// namespace matmul
}
// namespace armv7
}
// namespace megdnn
...
...
dnn/src/armv7/matrix_mul/opr_impl.cpp
浏览文件 @
eed54081
...
...
@@ -6,10 +6,11 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/armv7/matrix_mul/opr_impl.h"
#include "src/armv7/matrix_mul/algos.h"
#include "src/armv7/matrix_mul/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
...
...
@@ -21,7 +22,7 @@ using namespace armv7;
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoF32
f32
;
AlgoF32MK4Pack4x12
f32_mk4_pack_4x12
;
AlgoF32MK4_4x8
f32_mk4_4x8
;
AlgoF32MK4_4x8
f32_mk4_4x8
;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K4x16x1
f16_k4x16x1
;
AlgoF16MK8_4x8
f16_mk8_4x8
;
...
...
@@ -38,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoQuint8K4x8x8
quint8_k4x8x8
;
AlgoInt8x8x16K4x2x16
int8x8x16_k4x2x16
;
AlgoInt8x8x16K4x8x8
int8x8x16_k4x8x8
;
AlgoInt8x8x16MK4_8x8x4
int8x8x16_mk4_8x8x4
;
AlgoInt16x16x32K12x4x1
int16x16x32_k12x4x1
;
AlgoInt16x16x32MK8_4x8
int16x16x32_mk8_4x8
;
...
...
@@ -62,8 +64,10 @@ public:
all_algos
.
emplace_back
(
&
int8x8x32_k4x2x16
);
all_algos
.
emplace_back
(
&
int8x8x32_k4x8x8
);
all_algos
.
emplace_back
(
&
quint8_k4x8x8
);
all_algos
.
emplace_back
(
&
int8x8x16_mk4_8x8x4
);
all_algos
.
emplace_back
(
&
int8x8x16_k4x2x16
);
all_algos
.
emplace_back
(
&
int8x8x16_k4x8x8
);
all_algos
.
emplace_back
(
&
int16x16x32_k12x4x1
);
all_algos
.
emplace_back
(
&
int16x16x32_mk8_4x8
);
}
...
...
dnn/src/armv7/matrix_mul/opr_impl.h
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
...
...
@@ -19,26 +20,28 @@ public:
using
arm_common
::
MatrixMulImpl
::
MatrixMulImpl
;
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
private:
class
AlgoF32
;
// Armv7 F32
class
AlgoF32MK4Pack4x12
;
// Armv7 F32 Kernel 4x12 with pack
class
AlgoF32MK4_4x8
;
// Armv7 F32 Kernel 4x8 nopack
class
AlgoF32Gemv
;
// Armv7 F32 Gemv
class
AlgoInt8x8x32K4x8x8
;
// Armv7 Int8x8x32 Kernel 4x8x8
class
AlgoInt8x8x32K4x2x16
;
// Armv7 Int8x8x32 Kernel 4x2x16
class
AlgoF32
;
// Armv7 F32
class
AlgoF32MK4Pack4x12
;
// Armv7 F32 Kernel 4x12 with pack
class
AlgoF32MK4_4x8
;
// Armv7 F32 Kernel 4x8 nopack
class
AlgoF32Gemv
;
// Armv7 F32 Gemv
class
AlgoInt8x8x32K4x8x8
;
// Armv7 Int8x8x32 Kernel 4x8x8
class
AlgoInt8x8x32K4x2x16
;
// Armv7 Int8x8x32 Kernel 4x2x16
class
AlgoInt8x8x32MK4_4x2x16
;
// Armv7 Int8x8x32 Kernel MK4 4x2x16
class
AlgoQuint8K4x8x8
;
// Armv7 Quint8 Kernel 4x8x8
class
AlgoInt8x8x16K4x2x16
;
// Armv7 Int8x8x16 Kernel 4x2x16
class
AlgoInt8x8x16K4x8x8
;
// Armv7 Int8x8x16 Kernel 4x8x8
class
AlgoInt16x16x32K12x4x1
;
// Armv7 Int16x16x32 Kernel 12x4x1
class
AlgoInt16x16x32MK8_4x8
;
// Armv7 Int16x16x32 MK8 Format block 4x8
class
AlgoQuint8K4x8x8
;
// Armv7 Quint8 Kernel 4x8x8
class
AlgoInt8x8x16K4x2x16
;
// Armv7 Int8x8x16 Kernel 4x2x16
class
AlgoInt8x8x16K4x8x8
;
// Armv7 Int8x8x16 Kernel 4x8x8
class
AlgoInt8x8x16MK4_8x8x4
;
// Armv7 Int8x8x16 Kernel 8x8x4
class
AlgoInt16x16x32K12x4x1
;
// Armv7 Int16x16x32 Kernel 12x4x1
class
AlgoInt16x16x32MK8_4x8
;
// Armv7 Int16x16x32 MK8 Format block 4x8
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class
AlgoF16K4x16x1
;
// Armv7 F16 Kernel 4x16x1
class
AlgoF16MK8_4x8
;
// Armv7 F16 MK8 Format block 4x8
#endif
#if __ARM_FEATURE_DOTPROD
class
AlgoInt8x8x32K6x8x4
;
// Armv7 Int8 Kernel 6x8x4
class
AlgoQuint8DotK4x8x4
;
// Armv7 Quint8 Kernel 6x8x4
class
AlgoInt8x8x32K6x8x4
;
// Armv7 Int8 Kernel 6x8x4
class
AlgoQuint8DotK4x8x4
;
// Armv7 Quint8 Kernel 6x8x4
class
AlgoInt8x8x32MK4_8x4x4DotProd
;
// Armv7 nchw44 Int8x8x32 Kernel 8x4x4
// DotProduct
#endif
...
...
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
浏览文件 @
eed54081
...
...
@@ -10,9 +10,9 @@
* implied.
*/
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include "src/fallback/conv_bias/opr_impl.h"
...
...
@@ -194,10 +194,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
PW
=
param
.
filter_meta
.
padding
[
1
];
size_t
SH
=
param
.
filter_meta
.
stride
[
0
],
SW
=
param
.
filter_meta
.
stride
[
1
];
if
(
FH
!=
1
||
FW
!=
1
||
PH
||
PW
||
SH
!=
1
||
SW
!=
1
)
return
false
;
if
(
param
.
src_type
.
enumv
()
!=
param
.
filter_type
.
enumv
())
{
return
false
;
}
...
...
@@ -216,6 +214,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
//! is identity otherwise return false mean that 8x8x32 and 8x8x16
//! not support PostProcess
if
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
param
.
bias_mode
!=
megdnn
::
BiasMode
::
NO_BIAS
||
...
...
dnn/src/fallback/conv_bias/im2col/factory.h
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <unordered_map>
...
...
@@ -226,10 +227,10 @@ public:
PostprocessMode
::
FLOAT
,
"DefaultStrategyType::FLOAT"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
auto
matmul_block
=
matmul_algo
->
get_inner_block_size
();
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1
//! im2col+pack fuse
if
((
matmul_block
.
m
==
8
||
matmul_block
.
m
==
4
)
&&
matmul_block
.
n
==
12
&&
matmul_block
.
k
==
1
&&
param
.
filter_meta
.
spatial
[
0
]
==
3
&&
...
...
@@ -297,9 +298,21 @@ public:
break
;
case
StrategyType
::
INT8x8x16
:
cb2
(
NCHW
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCESS
,
"DefaultStrategyType::INT8x8x16"
_hash
);
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW
)
{
cb2
(
NCHW
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCESS
,
"DefaultStrategyType::INT8x8x16"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
cb2
(
NCHW44
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCESS
,
"DefaultStrategyType::INT8x8x16"
_hash
);
}
else
{
megdnn_throw
(
ssprintf
(
"Current only support layout "
"NCHW44/NCHW for im2col "
"algo, but got %d
\n
"
,
uint32_t
(
format
)));
}
break
;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case
StrategyType
::
QUINT8x8x32
:
...
...
@@ -421,10 +434,11 @@ public:
dt_int32
,
dt_int8
,
PostprocessMode
::
QUANTIZED
,
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"
_hash
);
}
else
{
megdnn_throw
(
ssprintf
(
"Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d
\n
"
,
uint32_t
(
format
)));
megdnn_throw
(
ssprintf
(
"Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d
\n
"
,
uint32_t
(
format
)));
}
break
;
}
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
eed54081
...
...
@@ -6,11 +6,12 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/matrix_mul/opr_impl.h"
namespace
megdnn
{
namespace
fallback
{
...
...
@@ -66,7 +67,8 @@ public:
};
typedef
void
(
*
kern_t
)(
const
KernParam
&
);
typedef
void
(
*
kern_naked_t
)(
const
KernParam
&
,
const
void
*
a_panel
,
const
void
*
b_panel
);
typedef
void
(
*
kern_naked_t
)(
const
KernParam
&
,
const
void
*
a_panel
,
const
void
*
b_panel
);
class
AlgoBase
:
public
Algorithm
{
protected:
virtual
~
AlgoBase
()
=
default
;
...
...
@@ -83,18 +85,19 @@ public:
bool
can_be_treated_as_int8x8x16
(
const
KernSizeParam
&
param
)
const
{
return
param
.
A_type
.
enumv
()
==
param
.
B_type
.
enumv
()
&&
param
.
A_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
param
.
C_type
.
enumv
()
==
DTypeEnum
::
Int16
&&
param
.
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
;
(
param
.
A_type
.
enumv
()
==
DTypeEnum
::
Int8
||
param
.
A_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
&&
(
param
.
C_type
.
enumv
()
==
DTypeEnum
::
Int16
||
param
.
C_type
.
enumv
()
==
DTypeEnum
::
QuantizedS16
)
;
}
public:
enum
class
AlgoSet
:
uint32_t
{
enum
class
AlgoSet
:
uint32_t
{
ALGO_TYPE_GEMM
=
0
,
ALGO_TYPE_GEMV
=
1
,
};
enum
class
PackMode
:
uint32_t
{
enum
class
PackMode
:
uint32_t
{
DEFAULT
=
0
,
NO_PACK
=
1
,
ONLY_PACKA
=
2
,
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
eed54081
...
...
@@ -489,25 +489,26 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle,
void
BENCHMARK_IM2COL_NCHW44_VS_NCHW
(
const
char
*
algo_name
,
const
char
*
im2col_name
,
Handle
*
handle
,
size_t
kernel
,
size_t
pack_size
=
1
)
{
auto
&&
args
=
get_winograd_benchmark_args
(
kernel
,
pack_size
);
size_t
kernel
,
DType
src_type
,
DType
dst_type
)
{
auto
&&
args
=
get_winograd_benchmark_args
(
kernel
,
4
);
using
namespace
conv_bias
;
constexpr
size_t
RUN
=
10
;
Benchmarker
<
ConvBias
>
benchmark
(
handle
);
benchmark
.
set_display
(
false
);
benchmark
.
set_times
(
RUN
);
benchmark
.
set_dtype
(
0
,
dtype
::
Int8
()
);
benchmark
.
set_dtype
(
1
,
dtype
::
Int8
()
);
benchmark
.
set_dtype
(
2
,
d
type
::
Int32
()
);
benchmark
.
set_dtype
(
4
,
d
type
::
Int32
()
);
benchmark
.
set_dtype
(
0
,
src_type
);
benchmark
.
set_dtype
(
1
,
src_type
);
benchmark
.
set_dtype
(
2
,
d
st_type
);
benchmark
.
set_dtype
(
4
,
d
st_type
);
Benchmarker
<
ConvBias
>
benchmark_im2col
(
handle
);
benchmark_im2col
.
set_display
(
false
);
benchmark_im2col
.
set_times
(
RUN
);
benchmark_im2col
.
set_dtype
(
0
,
dtype
::
Int8
()
);
benchmark_im2col
.
set_dtype
(
1
,
dtype
::
Int8
()
);
benchmark_im2col
.
set_dtype
(
2
,
d
type
::
Int32
()
);
benchmark_im2col
.
set_dtype
(
4
,
d
type
::
Int32
()
);
benchmark_im2col
.
set_dtype
(
0
,
src_type
);
benchmark_im2col
.
set_dtype
(
1
,
src_type
);
benchmark_im2col
.
set_dtype
(
2
,
d
st_type
);
benchmark_im2col
.
set_dtype
(
4
,
d
st_type
);
for
(
auto
&&
arg
:
args
)
{
TensorLayout
dst_layout
;
...
...
@@ -556,6 +557,7 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name,
computations
/
used_im2col
,
used
/
used_im2col
);
}
}
#if MEGDNN_AARCH64
TEST_F
(
ARM_COMMON
,
BENCHMARK_NCHW_VS_NCHW44_INT8x8x32
)
{
printf
(
"=========================compare "
...
...
@@ -563,7 +565,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16
\n
"
);
BENCHMARK_IM2COL_NCHW44_VS_NCHW
(
"IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16"
,
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16"
,
handle
(),
3
,
4
);
handle
(),
3
,
dtype
::
Int8
(),
dtype
::
Int32
());
}
#endif
#if MEGDNN_ARMV7
TEST_F
(
ARM_COMMON
,
BENCHMARK_NCHW_VS_NCHW44_INT8x8x16
)
{
const
char
*
default_algo
=
"IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"
;
const
char
*
mk4_algo
=
"IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"
;
printf
(
"compare %s vs %s
\n
"
,
default_algo
,
mk4_algo
);
BENCHMARK_IM2COL_NCHW44_VS_NCHW
(
default_algo
,
mk4_algo
,
handle
(),
3
,
dtype
::
Int8
(),
dtype
::
Int16
());
}
#endif
...
...
@@ -1860,15 +1872,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
//! channel bias
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
/
4
,
h
,
w
,
4
},
TensorShape
{
oc
/
4
,
ic
/
4
,
kernel
,
kernel
,
4
,
4
},
TensorShape
{
1
,
oc
/
4
,
1
,
1
,
4
});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
/
4
,
h
,
w
,
4
},
TensorShape
{
oc
/
4
,
ic
/
4
,
kernel
,
kernel
,
4
,
4
},
TensorShape
{
1
,
oc
/
4
,
1
,
1
,
4
});
};
for
(
size_t
stride
:
{
1
,
2
})
for
(
size_t
kernel
:
{
2
,
3
,
5
,
7
})
for
(
size_t
oc
:
{
64
})
for
(
size_t
oc
:
{
64
})
for
(
NonlineMode
nonline_mode
:
{
NonlineMode
::
IDENTITY
})
{
run
(
oc
,
oc
,
56
,
56
,
kernel
,
kernel
/
2
,
stride
,
nonline_mode
);
run
(
oc
,
oc
,
56
,
56
,
kernel
,
kernel
/
2
,
stride
,
nonline_mode
);
}
constexpr
size_t
RUN
=
50
;
...
...
@@ -1880,7 +1893,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
benchmark0
.
set_display
(
false
);
benchmark0
.
set_times
(
RUN
);
benchmark0
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBiasForward
>
(
"ARMDOTS8DIRECT_NCHW44"
));
conv_bias
::
ConvBiasAlgoChecker
<
ConvBiasForward
>
(
"ARMDOTS8DIRECT_NCHW44"
));
Benchmarker
<
ConvBias
>
benchmark1
(
handle
());
benchmark1
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
f
))
...
...
@@ -2002,15 +2016,20 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(
void
benchmark_conv1x1
(
const
char
*
matmul_algo_name
,
Handle
*
handle
,
DType
stype
,
DType
matmul_dtype
,
DType
bias_type
,
DType
conv_dtype
)
{
DType
conv_dtype
,
bool
is_mk4
=
false
)
{
using
namespace
conv_bias
;
int
pack_size
=
is_mk4
?
4
:
1
;
std
::
vector
<
TestArg
>
conv_bias_1x1_args
=
get_conv_bias_1x1_benchmark_args
();
get_conv_bias_1x1_benchmark_args
(
pack_size
);
constexpr
size_t
RUNS
=
50
;
param
::
MatrixMul
param
;
param
.
transposeA
=
false
;
param
.
transposeB
=
false
;
if
(
is_mk4
)
{
param
.
format
=
MatrixMul
::
Param
::
Format
::
MK4
;
}
Benchmarker
<
MatrixMul
>
benchmark_matmul
(
handle
);
benchmark_matmul
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
matmul_algo_name
));
...
...
@@ -2038,8 +2057,8 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
size_t
OH
=
arg
.
src
[
2
];
size_t
OW
=
arg
.
src
[
3
];
size_t
OC
=
arg
.
filter
[
0
];
size_t
M
=
OC
;
size_t
K
=
IC
;
size_t
M
=
OC
*
pack_size
;
size_t
K
=
IC
*
pack_size
;
size_t
N
=
OH
*
OW
;
float
computations
=
M
*
N
*
K
*
2.
f
/
(
1024
*
1024
*
1024
)
*
1e3
;
...
...
@@ -2047,6 +2066,10 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
TensorShape
A
,
B
;
A
=
TensorShape
{
M
,
K
};
B
=
TensorShape
{
K
,
N
};
if
(
is_mk4
)
{
A
=
TensorShape
{
M
/
4
,
K
/
4
,
4
,
4
};
B
=
TensorShape
{
K
/
4
,
N
,
4
};
}
auto
conv1x1_used
=
benchmark_conv1x1
.
set_param
(
arg
.
param
).
exec
(
{
arg
.
src
,
arg
.
filter
,
arg
.
bias
,
{},
{}})
/
...
...
@@ -2133,6 +2156,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) {
dtype
::
Int16
{},
dtype
::
Int16
{},
dtype
::
Int16
{});
benchmark_conv1x1
(
"ARMV7_INT8X8X16_K4X2X16"
,
handle
(),
dtype
::
Int8
{},
dtype
::
Int16
{},
dtype
::
Int16
{},
dtype
::
Int16
{});
benchmark_conv1x1
(
"ARMV7_INT8X8X16_MK4_K8X8X4"
,
handle
(),
dtype
::
Int8
{},
dtype
::
Int16
{},
dtype
::
Int16
{},
dtype
::
Int16
{},
true
);
#endif
}
...
...
@@ -2145,13 +2170,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
conv_param
.
pad_h
=
0
;
conv_param
.
pad_w
=
0
;
conv_param
.
nonlineMode
=
param
::
ConvBias
::
NonlineMode
::
IDENTITY
;
auto
run
=
[
&
](
size_t
M
,
size_t
K
){
auto
run
=
[
&
](
size_t
M
,
size_t
K
)
{
args
.
emplace_back
(
conv_param
,
TensorShape
{
1
,
K
,
1
,
1
},
TensorShape
{
M
,
K
,
1
,
1
},
TensorShape
{});
};
for
(
size_t
M
:
{
4
,
64
,
1024
,
4096
})
for
(
size_t
K
:
{
128
,
256
,
1024
,
4096
})
run
(
M
,
K
);
run
(
M
,
K
);
constexpr
size_t
RUNS
=
50
;
param
::
MatrixMul
param
;
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
eed54081
...
...
@@ -850,7 +850,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
param
::
ConvBias
::
Format
::
NCHW44
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS
)
{
using
namespace
conv_bias
;
std
::
vector
<
TestArg
>
args
=
get_nchw44_conv_bias_args
({
3
},
1
);
Checker
<
ConvBiasForward
,
OprWeightPreprocessProxy
<
ConvBiasForward
>>
checker
(
...
...
@@ -1131,7 +1132,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
1e-3
f
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS
)
{
using
namespace
conv_bias
;
Checker
<
ConvBiasForward
,
OprWeightPreprocessProxy
<
ConvBiasForward
>>
checker
(
...
...
@@ -2089,6 +2091,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_IM2COLMATMUL_INT8x8x16
)
{
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
std
::
vector
<
conv_bias
::
TestArg
>
args_nchw44
=
get_nchw44_conv_bias_args
({
2
,
3
,
4
,
5
,
6
,
7
},
1
,
true
,
true
,
true
,
false
,
false
,
false
,
false
,
true
);
std
::
vector
<
conv_bias
::
TestArg
>
args_nchw44_1x1s2
=
get_nchw44_conv_bias_args
({
1
},
2
,
true
,
true
,
true
,
false
,
false
,
false
,
false
,
true
);
#define cb(name) \
checker_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
...
...
@@ -2098,6 +2106,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name);
#define cb_nchw44(name) \
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); \
checker_conv_bias(args_nchw44_1x1s2, handle(), &rng, epsilon, \
dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \
dtype::Int16{}, name);
#if MEGDNN_AARCH64
cb
(
"IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"
);
cb
(
"IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"
);
...
...
@@ -2106,8 +2121,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
cb
(
"IM2COLMATMUL:ARM_COMMON_INT8X8X16"
);
cb
(
"IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"
);
cb
(
"IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"
);
cb_nchw44
(
"IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"
);
#endif
#undef cb
#undef cb_nchw44
}
#endif
...
...
@@ -2516,19 +2534,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
true
,
true
);
std
::
vector
<
conv_bias
::
TestArg
>
args_nchw44
=
get_nchw44_conv_bias_args
(
{
1
},
1
,
true
,
true
,
true
,
false
,
false
,
false
,
false
,
true
);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
#define cb_nchw44(name) \
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
#if MEGDNN_AARCH64
cb
(
"CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"
);
cb
(
"CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"
);
#elif MEGDNN_ARMV7
cb
(
"CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"
);
cb
(
"CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"
);
cb_nchw44
(
"CONV1x1:ARMV7_INT8X8X16_MK4_K8X8X4:48"
);
#endif
cb
(
"CONV1x1:ARM_COMMON_INT8X8X16:48"
);
#undef cb
#undef cb_nchw44
std
::
vector
<
conv_bias
::
TestArg
>
gemv_args
;
for
(
auto
&&
arg
:
args
)
...
...
dnn/test/armv7/matrix_mul.cpp
浏览文件 @
eed54081
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/armv7/fixture.h"
#include "test/common/benchmarker.h"
...
...
@@ -51,9 +52,15 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
handle
(),
"ARMV7_INT8X8X16_K4X8X8"
);
}
TEST_F
(
ARMV7
,
MATRIX_MUL_INT8x8x16_MK4_K8x8x4
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Int8
{},
dtype
::
Int8
{},
dtype
::
Int16
{},
handle
(),
"ARMV7_INT8X8X16_MK4_K8X8X4"
,
param
::
MatrixMul
::
Format
::
MK4
,
1
);
}
TEST_F
(
ARMV7
,
MATRIX_MUL_INT16x16x32
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Int16
{},
dtype
::
Int16
{},
dtype
::
Int32
{},
handle
(),
"ARMV7_INT16X16X32_K12X4X1"
);
matrix_mul
::
check_matrix_mul
(
dtype
::
Int16
{},
dtype
::
Int16
{},
dtype
::
Int32
{},
handle
(),
"ARMV7_INT16X16X32_K12X4X1"
);
}
TEST_F
(
ARMV7
,
MATRIX_MUL_INT16x16x32_MK8
)
{
...
...
@@ -83,7 +90,8 @@ TEST_F(ARMV7, MATRIX_MUL_SDOT) {
TEST_F
(
ARMV7
,
MATRIX_MUL_UDOT
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Quantized8Asymm
(
4.0
f
,
static_cast
<
uint8_t
>
(
10
)),
dtype
::
Quantized8Asymm
(
3.0
f
,
static_cast
<
uint8_t
>
(
54
)),
dtype
::
Quantized8Asymm
(
4.0
f
,
static_cast
<
uint8_t
>
(
10
)),
dtype
::
Quantized8Asymm
(
3.0
f
,
static_cast
<
uint8_t
>
(
54
)),
dtype
::
QuantizedS32
(
12.0
f
),
handle
(),
"AARCH32_QUINT8_K4X8X4"
);
}
...
...
@@ -103,7 +111,9 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) {
#if MEGDNN_WITH_BENCHMARK
namespace
{
void
run_8x8x16_benchmark
(
const
char
*
algo
,
Handle
*
handle
)
{
void
run_8x8x16_benchmark
(
const
char
*
algo
,
Handle
*
handle
,
MatrixMul
::
Param
::
Format
format
=
MatrixMul
::
Param
::
Format
::
DEFAULT
)
{
constexpr
size_t
RUNS
=
50
;
param
::
MatrixMul
param
;
Benchmarker
<
MatrixMul
>
benchmarker_int
(
handle
);
...
...
@@ -116,21 +126,31 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) {
.
set_dtype
(
2
,
dtype
::
Int16
{})
.
set_param
(
param
)
.
set_display
(
false
);
param
::
MatrixMul
target_param
;
target_param
.
format
=
format
;
benchmarker_int_kern_4x2x16
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
algo
));
benchmarker_int_kern_4x2x16
.
set_times
(
RUNS
)
.
set_dtype
(
0
,
dtype
::
Int8
{})
.
set_dtype
(
1
,
dtype
::
Int8
{})
.
set_dtype
(
2
,
dtype
::
Int16
{})
.
set_param
(
param
)
.
set_param
(
target_
param
)
.
set_display
(
false
);
Benchmarker
<
MatrixMul
>
benchmarker_float
(
handle
);
benchmarker_float
.
set_display
(
false
).
set_times
(
RUNS
);
auto
run
=
[
&
](
size_t
M
,
size_t
N
,
size_t
K
)
{
auto
int_used
=
benchmarker_int
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
auto
int_kern_used
=
benchmarker_int_kern_4x2x16
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
auto
int_kern_used
=
1e10
;
if
(
format
==
MatrixMul
::
Param
::
Format
::
MK4
)
{
int_kern_used
=
benchmarker_int_kern_4x2x16
.
exec
(
{{
M
/
4
,
K
/
4
,
4
,
4
},
{
K
/
4
,
N
,
4
},
{}})
/
RUNS
;
}
else
{
int_kern_used
=
benchmarker_int_kern_4x2x16
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
}
auto
float_used
=
benchmarker_float
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
float
computations
=
2.
f
*
M
*
K
*
N
*
1e-6
;
printf
(
"run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f "
...
...
@@ -145,6 +165,7 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) {
};
run
(
256
,
12
*
24
,
256
);
run
(
256
,
256
,
256
);
//////////////////////// gemv //////////////////////////
for
(
size_t
M
:
{
8
,
64
,
112
,
256
})
{
...
...
@@ -185,7 +206,8 @@ void run_16x16x32_benchmark(const char* algo, Handle* handle) {
"int: %f ms %f Gflops %s:
\n
"
"speedup(%s/arm_common, %s/float): %f
\n
"
,
M
,
K
,
N
,
float_used
,
computations
/
float_used
,
int_used
,
computations
/
int_used
,
algo
,
algo
,
algo
,
float_used
/
int_used
);
computations
/
int_used
,
algo
,
algo
,
algo
,
float_used
/
int_used
);
};
run
(
256
,
12
*
24
,
256
);
...
...
@@ -231,7 +253,8 @@ void run_8x8x32_benchmark(const char* algo, Handle* handle) {
"int: %f ms %f Gflops %s:
\n
"
"speedup(%s/arm_common, %s/float): %f
\n
"
,
M
,
K
,
N
,
float_used
,
computations
/
float_used
,
int_used
,
computations
/
int_used
,
algo
,
algo
,
algo
,
float_used
/
int_used
);
computations
/
int_used
,
algo
,
algo
,
algo
,
float_used
/
int_used
);
};
run
(
256
,
12
*
24
,
256
);
...
...
@@ -252,9 +275,11 @@ void run_8x8x32_quint_benchmark(Handle* handle) {
benchmarker_quint8_dot
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
"AARCH32_QUINT8_K4X8X4"
));
benchmarker_quint8_dot
.
set_times
(
RUNS
)
.
set_dtype
(
0
,
dtype
::
Quantized8Asymm
(
2.3
f
,
static_cast
<
uint8_t
>
(
20
)))
.
set_dtype
(
1
,
dtype
::
Quantized8Asymm
(
3.1
f
,
static_cast
<
uint8_t
>
(
30
)))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
2.3
f
*
3.1
f
))
.
set_dtype
(
0
,
dtype
::
Quantized8Asymm
(
2.3
f
,
static_cast
<
uint8_t
>
(
20
)))
.
set_dtype
(
1
,
dtype
::
Quantized8Asymm
(
3.1
f
,
static_cast
<
uint8_t
>
(
30
)))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
2.3
f
*
3.1
f
))
.
set_param
(
param
)
.
set_display
(
false
);
...
...
@@ -262,14 +287,17 @@ void run_8x8x32_quint_benchmark(Handle* handle) {
benchmarker_quint8
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
"ARMV7_QUINT8_K4X8X8"
));
benchmarker_quint8
.
set_times
(
RUNS
)
.
set_dtype
(
0
,
dtype
::
Quantized8Asymm
(
2.3
f
,
static_cast
<
uint8_t
>
(
20
)))
.
set_dtype
(
1
,
dtype
::
Quantized8Asymm
(
3.1
f
,
static_cast
<
uint8_t
>
(
30
)))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
2.3
f
*
3.1
f
))
.
set_dtype
(
0
,
dtype
::
Quantized8Asymm
(
2.3
f
,
static_cast
<
uint8_t
>
(
20
)))
.
set_dtype
(
1
,
dtype
::
Quantized8Asymm
(
3.1
f
,
static_cast
<
uint8_t
>
(
30
)))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
2.3
f
*
3.1
f
))
.
set_param
(
param
)
.
set_display
(
false
);
auto
run
=
[
&
](
size_t
M
,
size_t
N
,
size_t
K
)
{
auto
dot_used
=
benchmarker_quint8_dot
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
auto
dot_used
=
benchmarker_quint8_dot
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
auto
normal_used
=
benchmarker_quint8
.
exec
({{
M
,
K
},
{
K
,
N
},
{}})
/
RUNS
;
float
computations
=
2.
f
*
M
*
K
*
N
*
1e-6
;
printf
(
"run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops
\n
"
...
...
@@ -351,11 +379,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) {
run_8x8x16_benchmark
(
"ARMV7_INT8X8X16_K4X2X16"
,
handle
());
}
TEST_F
(
ARMV7
,
BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8
)
{
run_8x8x16_benchmark
(
"ARMV7_INT8X8X16_K4X8X8"
,
handle
());
}
TEST_F
(
ARMV7
,
BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8
)
{
run_8x8x16_benchmark
(
"ARMV7_INT8X8X16_MK4_K8X8X4"
,
handle
(),
MatrixMul
::
Param
::
Format
::
MK4
);
}
TEST_F
(
ARMV7
,
BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1
)
{
run_16x16x32_benchmark
(
"ARMV7_INT16X16X32_K12X4X1"
,
handle
());
}
...
...
dnn/test/common/matrix_mul.cpp
浏览文件 @
eed54081
...
...
@@ -6,12 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/common/matrix_mul.h"
#include "src/common/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/matrix_mul.h"
using
namespace
megdnn
;
using
namespace
test
;
...
...
@@ -39,9 +40,9 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() {
std
::
vector
<
matrix_mul
::
TestArg
>
matrix_mul
::
get_matmul_mk_packed_args
(
size_t
nbase
)
{
std
::
vector
<
TestArg
>
args
;
for
(
size_t
m
:
{
1
,
2
,
3
,
4
,
5
})
for
(
size_t
m
:
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
11
})
for
(
size_t
n
:
{
1
,
2
,
3
,
4
,
5
,
8
,
12
,
16
,
24
})
for
(
size_t
k
:
{
1
,
2
,
3
,
4
,
5
,
9
,
10
})
for
(
size_t
k
:
{
1
,
2
,
3
,
4
,
5
,
9
,
10
,
11
})
args
.
emplace_back
(
m
,
n
*
nbase
,
k
,
0
);
return
args
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录