Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
27ef788f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
27ef788f
编写于
5月 17, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/armv7): add armv7 mk4 matmul
GitOrigin-RevId: 8ef24bf53b19f7863b8a51da5a81135c941d1a72
上级
efb60be2
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
620 addition
and
7 deletion
+620
-7
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
+1
-1
dnn/src/armv7/matrix_mul/algos.cpp
dnn/src/armv7/matrix_mul/algos.cpp
+68
-0
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+11
-0
dnn/src/armv7/matrix_mul/asm/common.h
dnn/src/armv7/matrix_mul/asm/common.h
+56
-0
dnn/src/armv7/matrix_mul/fp32/strategy.h
dnn/src/armv7/matrix_mul/fp32/strategy.h
+3
-0
dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp
dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp
+451
-0
dnn/src/armv7/matrix_mul/opr_impl.cpp
dnn/src/armv7/matrix_mul/opr_impl.cpp
+2
-0
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+1
-0
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+12
-6
dnn/test/armv7/matrix_mul.cpp
dnn/test/armv7/matrix_mul.cpp
+15
-0
未找到文件。
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
浏览文件 @
27ef788f
...
@@ -707,7 +707,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
...
@@ -707,7 +707,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
"cmp %w[n_remain], #3\n" \
"cmp %w[n_remain], #3\n" \
"blt 22f\n" \
"blt 22f\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \
"b 2
3
f\n" \
"b 2
4
f\n" \
"22:\n" \
"22:\n" \
"cmp %w[n_remain], #2\n" \
"cmp %w[n_remain], #2\n" \
"blt 23f\n" \
"blt 23f\n" \
...
...
dnn/src/armv7/matrix_mul/algos.cpp
浏览文件 @
27ef788f
...
@@ -85,6 +85,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern,
...
@@ -85,6 +85,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern,
"AlgoF32Impl"
_hash
,
"AlgoF32Impl"
_hash
,
armv7
::
matmul
::
sgemm_4x12
,
float
,
float
);
armv7
::
matmul
::
sgemm_4x12
,
float
,
float
);
/* ===================== F32 algo mk4 K4x12 ===================== */
namespace {
//! Kernel entry point for the packed MK4 fp32 4x12 matmul.
//!
//! Builds an armv7::matmul::sgemm_mk4_pack_4x12 strategy from the sizes and
//! dtypes in \p kern_param and runs megdnn::matmul::GemmInterleaved over the
//! A/B/C buffers, using kern_param.workspace_ptr as scratch space.
void f32_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) {
    using Strategy = armv7::matmul::sgemm_mk4_pack_4x12;
    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
                 midout_iv("f32_mk4_pack_4x12_kern"_hash)) {
        // Problem sizes are shared by the strategy and the GEMM driver.
        auto rows = kern_param.M;
        auto cols = kern_param.N;
        auto depth = kern_param.K;

        Strategy strategy(rows, cols, depth, kern_param.A_type,
                          kern_param.B_type, kern_param.C_type);
        megdnn::matmul::GemmInterleaved<Strategy> gemm(
                rows, cols, depth, kern_param.trA, kern_param.trB, strategy);
        gemm.execute(kern_param.A<float>(), kern_param.LDA,
                     kern_param.B<float>(), kern_param.LDB,
                     kern_param.C<float>(), kern_param.LDC,
                     kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
//! Reports whether this algorithm can handle \p kern_size_param.
//!
//! Requires: default compute mode, MK4 format, A/B/C all Float32, neither
//! operand transposed, and both M and K divisible by 4.
bool MatrixMulImpl::AlgoF32MK4Pack4x12::usable(
        const KernSizeParam& kern_size_param) const {
    // The original predicate tested !trA && !trB twice; one test suffices.
    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.B_type == kern_size_param.A_type &&
           kern_size_param.C_type == kern_size_param.A_type &&
           kern_size_param.A_type == dtype::Float32() &&
           !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}
size_t
MatrixMulImpl
::
AlgoF32MK4Pack4x12
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_armv7_matmul_kern
,
midout_iv
(
"AlgoF32MK4Pack4x12::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
armv7
::
matmul
::
sgemm_mk4_pack_4x12
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
armv7
::
matmul
::
sgemm_mk4_pack_4x12
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
}
//! Returns the kernel entry point for this algorithm. The size parameter is
//! ignored: the same kernel, f32_mk4_pack_4x12_kern, is used for every shape.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern(
        const KernSizeParam&) const {
    MatrixMulImpl::kern_t kernel = f32_mk4_pack_4x12_kern;
    return kernel;
}
// Emits the out-of-class definitions matching MEGDNN_REG_GEMM_FUNC_FOR_IM2COL()
// for AlgoF32MK4Pack4x12, bound to the sgemm_mk4_pack_4x12 strategy with float
// input/output types.
// NOTE(review): the macro is defined elsewhere; the exact set of functions it
// generates is not visible here -- confirm against its definition.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12,
                                     megdnn_armv7_matmul_kern,
                                     "AlgoF32MK4Pack4x12"_hash,
                                     armv7::matmul::sgemm_mk4_pack_4x12, float,
                                     float);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== F16 K4x16x1 algo ===================== */
/* ===================== F16 K4x16x1 algo ===================== */
namespace
{
namespace
{
...
...
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
27ef788f
...
@@ -29,6 +29,17 @@ public:
...
@@ -29,6 +29,17 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
};
//! Armv7 fp32 matmul algorithm for the MK4 format that packs its operands and
//! uses a 4x12 kernel (see sgemm_mk4_pack_4x12).
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
public:
    //! This algorithm declares itself reproducible.
    bool is_reproducible() const override { return true; }
    //! Name used to select this algorithm (e.g. from tests).
    const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
    //! Whether the given problem description can be handled.
    bool usable(const KernSizeParam&) const override;
    //! Workspace size required for the given problem description.
    size_t get_workspace(const KernSizeParam&) const override;
    //! Kernel entry point to run for the given problem description.
    kern_t get_kern(const KernSizeParam&) const override;
    //! Algorithm type tag; returns sm_arm_common_algo_type.
    void* type() const override { return sm_arm_common_algo_type; }
    // NOTE(review): project macro; presumably declares the im2col GEMM hooks
    // that MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL defines -- confirm.
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class
MatrixMulImpl
::
AlgoF32MK4_4x8
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoF32MK4_4x8
final
:
public
AlgoBase
{
public:
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/armv7/matrix_mul/asm/common.h
浏览文件 @
27ef788f
...
@@ -1120,6 +1120,62 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1,
...
@@ -1120,6 +1120,62 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1,
:
"q0"
,
"q1"
,
"q2"
,
"memory"
);
:
"q0"
,
"q1"
,
"q2"
,
"memory"
);
}
}
/**
 * Transposes twelve consecutive groups of four 32-bit elements (48 elements).
 *
 * Reads 48 elements from \p inptr0, post-incrementing it past them (it is
 * passed by reference, so the caller sees the advance). Writes 48 elements to
 * \p outptr: first element 0 of each of the 12 groups, then element 1 of each
 * group, then element 2, then element 3.
 *
 * \param inptr0 source pointer; advanced by 48 elements on return
 * \param outptr destination for 48 elements (passed by value; the caller's
 *               pointer is not advanced)
 */
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4,
                  "transpose_1x12_4_s only support sizeof(T) == 4");
    asm volatile(
            // Each vld4.32 de-interleaves 8 elements (two groups) so that
            // d(4n+i) holds element i of both groups.
            "vld4.32 {d0-d3}, [%[inptr0]]!\n"
            "vld4.32 {d4-d7}, [%[inptr0]]!\n"
            "vld4.32 {d8-d11}, [%[inptr0]]!\n"
            "vld4.32 {d12-d15}, [%[inptr0]]!\n"
            "vld4.32 {d16-d19}, [%[inptr0]]!\n"
            "vld4.32 {d20-d23}, [%[inptr0]]!\n"
            // Pair up halves so every q register holds the same element index
            // of four adjacent groups: q0/q4/q8 = element 0,
            // q2/q6/q10 = element 1, q1/q5/q9 = element 2,
            // q3/q7/q11 = element 3.
            "vswp d1, d4\n"
            "vswp d3, d6\n"
            "vswp d9, d12\n"
            "vswp d11, d14\n"
            "vswp d17, d20\n"
            "vswp d19, d22\n"
            // Element 0 of all 12 groups.
            "vst1.32 {d0-d1}, [%[outptr]]!\n"
            "vst1.32 {d8-d9}, [%[outptr]]!\n"
            "vst1.32 {d16-d17}, [%[outptr]]!\n"
            // Element 1 of all 12 groups.
            "vst1.32 {d4-d5}, [%[outptr]]!\n"
            "vst1.32 {d12-d13}, [%[outptr]]!\n"
            "vst1.32 {d20-d21}, [%[outptr]]!\n"
            // Element 2 of all 12 groups.
            "vst1.32 {d2-d3}, [%[outptr]]!\n"
            "vst1.32 {d10-d11}, [%[outptr]]!\n"
            "vst1.32 {d18-d19}, [%[outptr]]!\n"
            // Element 3 of all 12 groups.
            "vst1.32 {d6-d7}, [%[outptr]]!\n"
            "vst1.32 {d14-d15}, [%[outptr]]!\n"
            "vst1.32 {d22-d23}, [%[outptr]]!\n"
            : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "memory");
}
/**
 * Transposes four consecutive groups of four 32-bit elements (16 elements).
 *
 * Reads 16 elements from \p inptr0, post-incrementing it past them (it is
 * passed by reference, so the caller sees the advance). Writes 16 elements to
 * \p outptr: element 0 of each of the 4 groups, then element 1, element 2 and
 * element 3.
 *
 * \param inptr0 source pointer; advanced by 16 elements on return
 * \param outptr destination for 16 elements (passed by value; the caller's
 *               pointer is not advanced)
 */
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4,
                  "transpose_1x4_4_s only support sizeof(T) == 4");
    asm volatile(
            // Each vld4.32 de-interleaves 8 elements (two groups) so that
            // d(4n+i) holds element i of both groups.
            "vld4.32 {d0-d3}, [%[inptr0]]!\n"
            "vld4.32 {d4-d7}, [%[inptr0]]!\n"
            // After the swaps: q0 = element 0, q2 = element 1,
            // q1 = element 2, q3 = element 3 of the four groups.
            "vswp d1, d4\n"
            "vswp d3, d6\n"
            "vst1.32 {d0-d1}, [%[outptr]]!\n"
            "vst1.32 {d4-d5}, [%[outptr]]!\n"
            "vst1.32 {d2-d3}, [%[outptr]]!\n"
            "vst1.32 {d6-d7}, [%[outptr]]!\n"
            : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
            :
            : "q0", "q1", "q2", "q3", "memory");
}
template
<
typename
T
>
template
<
typename
T
>
static
inline
void
transpose_4
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
static
inline
void
transpose_4
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*
outptr
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*
outptr
,
...
...
dnn/src/armv7/matrix_mul/fp32/strategy.h
浏览文件 @
27ef788f
...
@@ -18,6 +18,9 @@ namespace matmul {
...
@@ -18,6 +18,9 @@ namespace matmul {
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
12
,
1
,
false
,
true
,
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
12
,
1
,
false
,
true
,
sgemm_4x12
);
sgemm_4x12
);
// Declares the sgemm_mk4_pack_4x12 strategy (float A/B/C) used by the packed
// MK4 4x12 matmul; its pack_A / pack_B / kern members are implemented in
// strategy_mk_4x12.cpp.
// NOTE(review): the arguments 4, 12, 1 presumably are the A interleave,
// B interleave and K unroll, and the two booleans are strategy flags whose
// meaning is defined by the macro -- confirm against MEGDNN_REG_GEMM_STRATEGY.
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, false,
                         sgemm_mk4_pack_4x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
float
,
float
,
float
,
4
,
8
,
1
,
false
,
true
,
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
float
,
float
,
float
,
4
,
8
,
1
,
false
,
true
,
sgemm_nopack_4x8
);
sgemm_nopack_4x8
);
...
...
dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp
0 → 100644
浏览文件 @
27ef788f
/**
* \file dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/armv7/matrix_mul/fp32/strategy.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
using
namespace
megdnn
;
using
namespace
armv7
;
using
namespace
armv7
::
matmul
;
namespace
{
// Overview of register layout:
//
// A 1x12 cell of Rhs is stored in 32bit in q1-q3
// A 4x1 cell of Lhs is stored in 132bit in q0
// A 4x12 block of accumulators is stored in 32bit in q4-q15.
//
// +--------+--------+--------+
// | q1[0-3]| q2[0-3]| q3[0-3]|
// Rhs +--------+--------+--------+
//
// | | | |
//
// Lhs | | | |
//
// +--+ - - - - +--------+--------+--------+
// |q0| | q4[0-3]| q5[0-3]| q6[0-3]|
// |q0| | q7[0-3]| q8[0-3]| q9[0-3]|
// |q0| |q10[0-3]|q11[0-3]|q12[0-3]|
// |q0| |q13[0-3]|q14[0-3]|q15[0-3]|
// +--+ - - - - +--------+--------+--------+
//
// Accumulator
void
kern_4x12
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
)
{
MEGDNN_MARK_USED_VAR
(
LDC
);
const
float
*
a_ptr
=
packA
;
const
float
*
b_ptr
=
packB
;
float
*
output0
=
output
;
int
oddk
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
asm
volatile
(
"cmp %[is_first_k], #1
\n
"
"beq 1f
\n
"
"mov r1, %[output0]
\n
"
"vld1.32 {d8-d11}, [r1]!
\n
"
"vld1.32 {d12-d15}, [r1]!
\n
"
"vld1.32 {d16-d19}, [r1]!
\n
"
"vld1.32 {d20-d23}, [r1]!
\n
"
"vld1.32 {d24-d27}, [r1]!
\n
"
"vld1.32 {d28-d31}, [r1]!
\n
"
"vld1.32 {d0-d1}, [%[a_ptr]]!
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"b 2f
\n
"
"1:
\n
"
"veor.32 q4, q4, q4
\n
"
"pld [%[output0]]
\n
"
"veor.32 q5, q4, q4
\n
"
"veor.32 q6, q4, q4
\n
"
"veor.32 q7, q4, q4
\n
"
"vld1.32 {d0-d1}, [%[a_ptr]]!
\n
"
"veor.32 q8, q4, q4
\n
"
"veor.32 q9, q4, q4
\n
"
"veor.32 q10, q4, q4
\n
"
"veor.32 q11, q4, q4
\n
"
"vld1.32 {d4-d7}, [%[b_ptr]]!
\n
"
"veor.32 q12, q4, q4
\n
"
"veor.32 q13, q4, q4
\n
"
"veor.32 q14, q4, q4
\n
"
"veor.32 q15, q4, q4
\n
"
"2:
\n
"
"cmp %[K], #0
\n
"
"beq 4f
\n
"
"3:
\n
"
"vmla.f32 q4, q0, d4[0]
\n
"
"vmla.f32 q5, q0, d4[1]
\n
"
"vmla.f32 q6, q0, d5[0]
\n
"
"vmla.f32 q7, q0, d5[1]
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"vmla.f32 q8, q0, d6[0]
\n
"
"vmla.f32 q9, q0, d6[1]
\n
"
"vmla.f32 q10, q0, d7[0]
\n
"
"vld1.32 {d2-d3}, [%[a_ptr]]!
\n
"
"vmla.f32 q11, q0, d7[1]
\n
"
"vld1.32 {d6-d7}, [%[b_ptr]]!
\n
"
"vmla.f32 q12, q0, d4[0]
\n
"
"vmla.f32 q13, q0, d4[1]
\n
"
"vmla.f32 q14, q0, d5[0]
\n
"
"vmla.f32 q15, q0, d5[1]
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"vmla.f32 q4, q1, d6[0]
\n
"
"subs %[K], %[K], #1
\n
"
"vmla.f32 q5, q1, d6[1]
\n
"
"vmla.f32 q6, q1, d7[0]
\n
"
"vmla.f32 q7, q1, d7[1]
\n
"
"vld1.32 {d6-d7}, [%[b_ptr]]!
\n
"
"vmla.f32 q8, q1, d4[0]
\n
"
"vmla.f32 q9, q1, d4[1]
\n
"
"vld1.32 {d0-d1}, [%[a_ptr]]!
\n
"
"vmla.f32 q10, q1, d5[0]
\n
"
"vmla.f32 q11, q1, d5[1]
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"vmla.f32 q12, q1, d6[0]
\n
"
"vmla.f32 q13, q1, d6[1]
\n
"
"vmla.f32 q14, q1, d7[0]
\n
"
"vmla.f32 q15, q1, d7[1]
\n
"
"vld1.32 {d6-d7}, [%[b_ptr]]!
\n
"
"bne 3b
\n
"
"4:
\n
"
"cmp %[oddk], #1
\n
"
"beq 5f
\n
"
// Even tail
"vmla.f32 q4, q0, d4[0]
\n
"
"vmla.f32 q5, q0, d4[1]
\n
"
"vmla.f32 q6, q0, d5[0]
\n
"
"vmla.f32 q7, q0, d5[1]
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"vmla.f32 q8, q0, d6[0]
\n
"
"vmla.f32 q9, q0, d6[1]
\n
"
"vmla.f32 q10, q0, d7[0]
\n
"
"vld1.32 {d2-d3}, [%[a_ptr]]!
\n
"
"vmla.f32 q11, q0, d7[1]
\n
"
"vld1.32 {d6-d7}, [%[b_ptr]]!
\n
"
"vmla.f32 q12, q0, d4[0]
\n
"
"vmla.f32 q13, q0, d4[1]
\n
"
"vmla.f32 q14, q0, d5[0]
\n
"
"vmla.f32 q15, q0, d5[1]
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"vmla.f32 q4, q1, d6[0]
\n
"
"subs %[K], %[K], #1
\n
"
"vmla.f32 q5, q1, d6[1]
\n
"
"vmla.f32 q6, q1, d7[0]
\n
"
"vmla.f32 q7, q1, d7[1]
\n
"
"vld1.32 {d6-d7}, [%[b_ptr]]!
\n
"
"vmla.f32 q8, q1, d4[0]
\n
"
"vmla.f32 q9, q1, d4[1]
\n
"
"vst1.32 {d8-d11}, [%[output0]]!
\n
"
"vmla.f32 q10, q1, d5[0]
\n
"
"vmla.f32 q11, q1, d5[1]
\n
"
"vst1.32 {d12-d15}, [%[output0]]!
\n
"
"vmla.f32 q12, q1, d6[0]
\n
"
"vmla.f32 q13, q1, d6[1]
\n
"
"vst1.32 {d16-d19}, [%[output0]]!
\n
"
"vmla.f32 q14, q1, d7[0]
\n
"
"vmla.f32 q15, q1, d7[1]
\n
"
"vst1.32 {d20-d23}, [%[output0]]!
\n
"
"vst1.32 {d24-d27}, [%[output0]]!
\n
"
"vst1.32 {d28-d31}, [%[output0]]!
\n
"
"b 6f
\n
"
// odd tail
"5:
\n
"
"vmla.f32 q4, q0, d4[0]
\n
"
"vmla.f32 q5, q0, d4[1]
\n
"
"vmla.f32 q6, q0, d5[0]
\n
"
"vmla.f32 q7, q0, d5[1]
\n
"
"vld1.32 {d4-d5}, [%[b_ptr]]!
\n
"
"vmla.f32 q8, q0, d6[0]
\n
"
"vst1.32 {d8-d11}, [%[output0]]!
\n
"
"vmla.f32 q9, q0, d6[1]
\n
"
"vmla.f32 q10, q0, d7[0]
\n
"
"vst1.32 {d12-d15}, [%[output0]]!
\n
"
"vmla.f32 q11, q0, d7[1]
\n
"
"vmla.f32 q12, q0, d4[0]
\n
"
"vst1.32 {d16-d19}, [%[output0]]!
\n
"
"vmla.f32 q13, q0, d4[1]
\n
"
"vst1.32 {d20-d23}, [%[output0]]!
\n
"
"vmla.f32 q14, q0, d5[0]
\n
"
"vst1.32 {d24-d27}, [%[output0]]!
\n
"
"vmla.f32 q15, q0, d5[1]
\n
"
"vst1.32 {d28-d31}, [%[output0]]!
\n
"
"6:
\n
"
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
K
]
"+r"
(
K
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
oddk
]
"+r"
(
oddk
),
[
output0
]
"+r"
(
output0
)
:
:
"q0"
,
"q1"
,
"q2"
,
"q3"
,
"q4"
,
"q5"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
,
"r1"
,
"cc"
,
"memory"
);
}
// Overview of register layout:
//
// A 1x4 cell of Rhs is stored in 32bit in q2 (q3 for the second k step)
// A 4x1 cell of Lhs is stored in 32bit in q0 (q1 for the second k step)
// A 4x4 block of accumulators is stored in 32bit in q4-q7, one q register
// per output column: q(4 + j) holds the 4 rows of column j.
//
//                 +-----+-----+-----+-----+
//                 |b[0] |b[1] |b[2] |b[3] |
//            Rhs  +-----+-----+-----+-----+
//
//    Lhs          |     |     |     |     |
//
//  +--+ - - - -   +-----+-----+-----+-----+
//  |q0|           | q4  | q5  | q6  | q7  |
//  +--+ - - - -   +-----+-----+-----+-----+
//
//                        Accumulator

/**
 * Computes one 4 x n_remain output tile (n_remain in 1..4):
 * output (+)= packA * packB.
 *
 * All four columns are always computed; only the first \p n_remain columns
 * are loaded from / stored to \p output.
 *
 * \param packA packed Lhs panel; 4 floats are consumed per k step
 * \param packB packed Rhs panel; 4 floats are consumed per k step
 * \param K number of k steps
 * \param output tile destination, n_remain columns of 4 contiguous floats
 * \param LDC unused
 * \param is_first_k when true the accumulators start from zero, otherwise
 *                   they are first loaded from \p output
 * \param n_remain number of valid output columns
 */
void kern_4x4(const float* packA, const float* packB, int K, float* output,
              int LDC, bool is_first_k, int n_remain) {
    MEGDNN_MARK_USED_VAR(LDC);
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    // The main loop handles two k steps per iteration; the last one or two
    // steps are peeled into the odd/even tails.
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;

// LOAD_C reads the first n_remain columns of the tile through r1 into q4..;
// STORE_C writes the first n_remain columns back through %[output].
//clang-format off
#define LOAD_C                          \
    "cmp %[n_remain], #4\n"             \
    "blt 11f\n"                         \
    "vld1.32 {d8-d11}, [r1]!\n"         \
    "vld1.32 {d12-d15}, [r1]!\n"        \
    "b 14f\n"                           \
    "11:\n"                             \
    "cmp %[n_remain], #3\n"             \
    "blt 12f\n"                         \
    "vld1.32 {d8-d11}, [r1]!\n"         \
    "vld1.32 {d12-d13}, [r1]!\n"        \
    "b 14f\n"                           \
    "12:\n"                             \
    "cmp %[n_remain], #2\n"             \
    "blt 13f\n"                         \
    "vld1.32 {d8-d11}, [r1]\n"          \
    "b 14f\n"                           \
    "13:\n"                             \
    "vld1.32 {d8-d9}, [r1]\n"           \
    "14:\n"

#define STORE_C                         \
    "cmp %[n_remain], #4\n"             \
    "blt 21f\n"                         \
    "vst1.32 {d8-d11}, [%[output]]!\n"  \
    "vst1.32 {d12-d15}, [%[output]]!\n" \
    "b 24f\n"                           \
    "21:\n"                             \
    "cmp %[n_remain], #3\n"             \
    "blt 22f\n"                         \
    "vst1.32 {d8-d11}, [%[output]]!\n"  \
    "vst1.32 {d12-d13}, [%[output]]!\n" \
    "b 24f\n"                           \
    "22:\n"                             \
    "cmp %[n_remain], #2\n"             \
    "blt 23f\n"                         \
    "vst1.32 {d8-d11}, [%[output]]!\n"  \
    "b 24f\n"                           \
    "23:\n"                             \
    "vst1.32 {d8-d9}, [%[output]]!\n"   \
    "24:\n"
    //clang-format on

    asm volatile(
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            // Accumulate path: load the existing columns, then the first
            // Lhs column and Rhs row.
            "mov r1, %[output]\n" LOAD_C

            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "b 2f\n"

            "1:\n"
            // First-k path: zero the accumulators, interleaved with the
            // initial Lhs/Rhs loads.
            "veor.32 q4, q4, q4\n"
            "pld [%[output]]\n"
            "veor.32 q5, q4, q4\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "veor.32 q6, q4, q4\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "veor.32 q7, q4, q4\n"

            "2:\n"
            "cmp %[K], #0\n"
            "beq 4f\n"

            // Main loop: two k steps per iteration (q0*q2, then q1*q3).
            "3:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"

            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "bne 3b\n"

            "4:\n"
            "cmp %[oddk], #1\n"
            "beq 5f\n"

            // Even tail: two remaining k steps.
            "vmla.f32 q4, q0, d4[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"

            "vmla.f32 q4, q1, d6[0]\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "b 6f\n"

            // Odd tail: one remaining k step.
            "5:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"

            "6:\n" STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
              [output] "+r"(output), [n_remain] "+r"(n_remain)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc",
              "memory");
#undef LOAD_C
#undef STORE_C
}
}
// namespace
// Emits the strategy boilerplate for sgemm_mk4_pack_4x12; the hand-written
// pack_A / pack_B / kern members follow.
// NOTE(review): the macro is defined elsewhere; what exactly it generates is
// not visible here -- confirm against its definition.
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_pack_4x12);
//! conv1x1 and im2col currently have no matmul mode that packs only B, so A
//! (the weight) is simply copied.
//!
//! For every block of 4 rows in [y0, ymax) copies the (kmax - k0) * 4 floats
//! starting at column offset k0 * 4 of that block's row in \p in (row stride
//! \p ldin) to \p out, back to back.
void sgemm_mk4_pack_4x12::pack_A(float* out, const float* in, int ldin, int y0,
                                 int ymax, int k0, int kmax, bool) const {
    megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int PACK_C_SIZE = 4;

    const size_t floats_per_block = (kmax - k0) * PACK_C_SIZE;
    const size_t bytes_per_block = floats_per_block * sizeof(float);

    float* dst = out;
    int row = y0;
    while (row < ymax) {
        const float* block_begin =
                in + (row / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE;
        memcpy(dst, block_begin, bytes_per_block);
        dst += floats_per_block;
        row += 4;
    }
}
void
sgemm_mk4_pack_4x12
::
pack_B
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose_B
)
const
{
megdnn_assert
(
!
transpose_B
);
megdnn_assert
(
k0
%
4
==
0
&&
kmax
%
4
==
0
,
"K must be time of 4"
);
float
tmpbuff
[
16
]
=
{
0.0
f
};
constexpr
int
PACK_C_SIZE
=
4
;
int
ksize
=
kmax
-
k0
;
int
ksize12
=
ksize
*
12
;
int
ksize4
=
(
ksize
<<
2
);
float
*
outptr_base
=
out
;
float
*
outptr_base4
=
outptr_base
+
(
xmax
-
x0
)
/
12
*
ksize12
;
int
k
=
k0
;
for
(;
k
+
3
<
kmax
;
k
+=
4
)
{
const
float
*
inptr
=
in
+
k
/
PACK_C_SIZE
*
ldin
+
x0
*
PACK_C_SIZE
;
prefetch_3x
(
inptr
);
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
12
<=
xmax
;
x
+=
12
)
{
auto
outptr_interleave
=
outptr
;
transpose_1x12_4_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize12
;
}
outptr
=
outptr_base4
;
for
(;
x
+
4
<=
xmax
;
x
+=
4
)
{
auto
outptr_interleave
=
outptr
;
transpose_1x4_4_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize4
;
}
if
(
x
<
xmax
)
{
memcpy
(
tmpbuff
,
inptr
,
sizeof
(
float
)
*
(
xmax
-
x
)
*
PACK_C_SIZE
);
auto
outptr_interleave
=
outptr
;
const
float
*
tmp_ptr
=
&
tmpbuff
[
0
];
transpose_1x4_4_s
<
float
>
(
tmp_ptr
,
outptr_interleave
);
outptr
+=
ksize4
;
}
outptr_base
+=
12
*
PACK_C_SIZE
;
outptr_base4
+=
4
*
PACK_C_SIZE
;
}
}
//! Runs the packed GEMM over the whole M x N problem.
//!
//! Walks M in steps of 4 rows; for each row block it walks N in tiles of 12
//! columns (kern_4x12) and then in tiles of up to 4 columns (kern_4x4, which
//! receives the number of valid columns). The last two pointer parameters are
//! unused.
void sgemm_mk4_pack_4x12::kern(const float* packA, const float* packB,
                               size_t M, size_t N, size_t K, float* C,
                               size_t LDC, bool is_first_k, const float*,
                               float*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                  A_dtype.enumv() == C_dtype.enumv() &&
                  A_dtype.enumv() == DTypeEnum::Float32);
    constexpr int PACK_C_SIZE = 4;
    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 12;
    // Panel sizes, in floats, of one packed 12-wide / 4-wide slab.
    const int panel12_floats = K * 12;
    const int panel4_floats = K * 4;

    const float* a_panel = packA;
    for (size_t row = 0; row < M; row += A_INTERLEAVE) {
        float* tile_out = C + (row / 4 * LDC);
        const float* b_panel = packB;
        size_t col = 0;

        // Full 12-column tiles.
        while (col + B_INTERLEAVE - 1 < N) {
            kern_4x12(a_panel, b_panel, K, tile_out, LDC, is_first_k);
            tile_out += PACK_C_SIZE * B_INTERLEAVE;
            b_panel += panel12_floats;
            col += B_INTERLEAVE;
        }

        // Remaining columns, up to 4 at a time.
        while (col < N) {
            const size_t cols_left = std::min<size_t>(N - col, 4);
            kern_4x4(a_panel, b_panel, K, tile_out, LDC, is_first_k,
                     cols_left);
            tile_out += PACK_C_SIZE * 4;
            b_panel += panel4_floats;
            col += 4;
        }

        a_panel += panel4_floats;
    }
}
// vim: syntax=cpp.doxygen
dnn/src/armv7/matrix_mul/opr_impl.cpp
浏览文件 @
27ef788f
...
@@ -20,6 +20,7 @@ using namespace armv7;
...
@@ -20,6 +20,7 @@ using namespace armv7;
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoF32
f32
;
AlgoF32
f32
;
AlgoF32MK4Pack4x12
f32_mk4_pack_4x12
;
AlgoF32MK4_4x8
f32_mk4_4x8
;
AlgoF32MK4_4x8
f32_mk4_4x8
;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16K4x16x1
f16_k4x16x1
;
AlgoF16K4x16x1
f16_k4x16x1
;
...
@@ -48,6 +49,7 @@ public:
...
@@ -48,6 +49,7 @@ public:
AlgoPack
()
{
AlgoPack
()
{
all_algos
.
emplace_back
(
&
f32_gemv
);
all_algos
.
emplace_back
(
&
f32_gemv
);
all_algos
.
emplace_back
(
&
f32
);
all_algos
.
emplace_back
(
&
f32
);
all_algos
.
emplace_back
(
&
f32_mk4_pack_4x12
);
all_algos
.
emplace_back
(
&
f32_mk4_4x8
);
all_algos
.
emplace_back
(
&
f32_mk4_4x8
);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos
.
emplace_back
(
&
f16_k4x16x1
);
all_algos
.
emplace_back
(
&
f16_k4x16x1
);
...
...
dnn/src/armv7/matrix_mul/opr_impl.h
浏览文件 @
27ef788f
...
@@ -21,6 +21,7 @@ public:
...
@@ -21,6 +21,7 @@ public:
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
private:
private:
class
AlgoF32
;
// Armv7 F32
class
AlgoF32
;
// Armv7 F32
class
AlgoF32MK4Pack4x12
;
// Armv7 F32 Kernel 4x12 with pack
class
AlgoF32MK4_4x8
;
// Armv7 F32 Kernel 4x8 nopack
class
AlgoF32MK4_4x8
;
// Armv7 F32 Kernel 4x8 nopack
class
AlgoF32Gemv
;
// Armv7 F32 Gemv
class
AlgoF32Gemv
;
// Armv7 F32 Gemv
class
AlgoInt8x8x32K4x8x8
;
// Armv7 Int8x8x32 Kernel 4x8x8
class
AlgoInt8x8x32K4x8x8
;
// Armv7 Int8x8x32 Kernel 4x8x8
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
27ef788f
...
@@ -1287,23 +1287,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
...
@@ -1287,23 +1287,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
#undef cb
#undef cb
}
}
#if MEGDNN_AARCH64
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COL_S1_MK4_PACK_F32
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COL_S1_MK4_PACK_F32
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
4
,
7
},
1
);
get_nchw44_conv_bias_args
({
2
,
4
,
7
},
1
);
#if MEGDNN_AARCH64
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"
);
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"
);
}
#elif MEGDNN_ARMV7
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"
);
#endif
#endif
}
#if MEGDNN_AARCH64
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COL_S2_MK4_PACK_F32
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COL_S2_MK4_PACK_F32
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
({
3
,
5
,
6
},
2
);
get_nchw44_conv_bias_args
({
3
,
5
,
6
},
2
);
#if MEGDNN_AARCH64
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"
);
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"
);
}
#elif MEGDNN_ARMV7
check_conv_bias
(
args
,
handle
(),
"IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"
);
#endif
#endif
}
/***************************** Conv1x1 Algo Test ***********************/
/***************************** Conv1x1 Algo Test ***********************/
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_F32
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_F32
)
{
...
@@ -1316,14 +1320,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
...
@@ -1316,14 +1320,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
#endif
#endif
}
}
#if MEGDNN_AARCH64
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_MK4_PACK_F32
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_MK4_PACK_F32
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
({
1
},
1
,
true
,
false
,
false
);
get_nchw44_conv_bias_args
({
1
},
1
,
true
,
false
,
false
);
#if MEGDNN_AARCH64
check_conv_bias
(
args
,
handle
(),
"CONV1x1:AARCH64_F32_MK4_K8X12X1:24"
);
check_conv_bias
(
args
,
handle
(),
"CONV1x1:AARCH64_F32_MK4_K8X12X1:24"
);
}
#elif MEGDNN_ARMV7
check_conv_bias
(
args
,
handle
(),
"CONV1x1:ARMV7_F32_MK4_PACK_4X12:24"
);
#endif
#endif
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_MK4_NO_PACK_F32
)
{
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_MK4_NO_PACK_F32
)
{
using
namespace
conv_bias
;
using
namespace
conv_bias
;
...
...
dnn/test/armv7/matrix_mul.cpp
浏览文件 @
27ef788f
...
@@ -28,6 +28,12 @@ TEST_F(ARMV7, MATRIX_MUL_MK4) {
...
@@ -28,6 +28,12 @@ TEST_F(ARMV7, MATRIX_MUL_MK4) {
"ARMV7_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
,
4
);
"ARMV7_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
,
4
);
}
}
// Correctness check of the packed MK4 fp32 algorithm, selected by name.
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
    const auto format = param::MatrixMul::Format::MK4;
    matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                 dtype::Float32{}, handle(),
                                 "ARMV7_F32_MK4_PACK_4X12", format, 1);
}
TEST_F
(
ARMV7
,
MATRIX_MUL_MK4_INT8
)
{
TEST_F
(
ARMV7
,
MATRIX_MUL_MK4_INT8
)
{
std
::
vector
<
matrix_mul
::
TestArg
>
args
;
std
::
vector
<
matrix_mul
::
TestArg
>
args
;
for
(
size_t
m
:
{
1
,
2
,
3
,
4
,
5
,
7
,
10
,
11
})
for
(
size_t
m
:
{
1
,
2
,
3
,
4
,
5
,
7
,
10
,
11
})
...
@@ -349,6 +355,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) {
...
@@ -349,6 +355,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) {
dtype
::
Float32
{});
dtype
::
Float32
{});
}
}
// Benchmarks the packed MK4 fp32 algorithm via benchmark_with_contrast, using
// the shapes from get_benchmark_matmul_mk_packed_args(8).
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_PACK_MK4) {
    const auto format = param::MatrixMul::Format::MK4;
    auto cases = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), cases, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{}, "ARMV7_F32_MK4_PACK_4X12", format,
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{});
}
TEST_F
(
ARMV7
,
BENCHMARK_MATRIX_MUL_INT16x16x32_MK8
)
{
TEST_F
(
ARMV7
,
BENCHMARK_MATRIX_MUL_INT16x16x32_MK8
)
{
auto
args
=
matrix_mul
::
get_benchmark_matmul_mk_packed_args
(
4
);
auto
args
=
matrix_mul
::
get_benchmark_matmul_mk_packed_args
(
4
);
matrix_mul
::
benchmark_with_contrast
(
matrix_mul
::
benchmark_with_contrast
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录