Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
86cf7490
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
86cf7490
编写于
9月 16, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/aarch64): add quantizeds4 matmul int4x4x16_k8x8x8
GitOrigin-RevId: 781290024466459c292cb46d7c97c5d39454525f
上级
bff0fc61
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
1466 addition
and
3 deletion
+1466
-3
dnn/include/megdnn/oprs/base.h
dnn/include/megdnn/oprs/base.h
+1
-0
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+72
-0
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+13
-0
dnn/src/aarch64/matrix_mul/asm/common.h
dnn/src/aarch64/matrix_mul/asm/common.h
+85
-0
dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
+913
-0
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
+109
-0
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
+26
-0
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+2
-0
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+2
-2
dnn/src/common/matrix_mul.cpp
dnn/src/common/matrix_mul.cpp
+4
-0
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+1
-0
dnn/src/naive/matrix_mul/matrix_mul_helper.h
dnn/src/naive/matrix_mul/matrix_mul_helper.h
+36
-0
dnn/src/naive/matrix_mul/opr_impl.cpp
dnn/src/naive/matrix_mul/opr_impl.cpp
+7
-1
dnn/test/aarch64/matrix_mul.cpp
dnn/test/aarch64/matrix_mul.cpp
+106
-0
dnn/test/common/rng.cpp
dnn/test/common/rng.cpp
+28
-0
dnn/test/naive/matrix_mul.cpp
dnn/test/naive/matrix_mul.cpp
+61
-0
未找到文件。
dnn/include/megdnn/oprs/base.h
浏览文件 @
86cf7490
...
...
@@ -88,6 +88,7 @@ enum class AlgoDataType : uint32_t {
QUINT8X8X32
=
1
<<
3
,
INT8X8X16
=
1
<<
4
,
INT16X16X32
=
1
<<
5
,
INT4X4X16
=
1
<<
6
,
};
/*!
...
...
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
86cf7490
...
...
@@ -17,6 +17,7 @@
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
...
...
@@ -1394,4 +1395,75 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
aarch64
::
matmul
::
gemm_s8x8x16_mk4_8x8x8
,
int8_t
,
int16_t
,
AlgoDataType
::
INT8X8X16
,
MK4
);
/* ===================== Int4x4x16 K8x8x8 algo ===================== */
namespace {
/*!
 * \brief Kernel entry returned by AlgoInt4x4x16K8x8x8::get_kern().
 *
 * Builds a gemm_s4x4x16_s4_8x8x8 strategy from the problem description in
 * \p kern_param and runs it through GemmInterleaved, reading A/B as dt_int8
 * buffers and writing C as dt_int16, using kern_param.workspace_ptr as
 * scratch memory.
 *
 * NOTE(review): the function name says "k8x8x16" while the midout tag and the
 * strategy are "8x8x8" -- presumably a naming slip; confirm before renaming,
 * since get_kern() refers to this symbol.
 */
void int4x4x16_k8x8x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int4x4x16_k8x8x8_kern"_hash)) {
        using Strategy = aarch64::matmul::gemm_s4x4x16_s4_8x8x8;

        // Problem size and transposition flags.
        auto rows = kern_param.M, cols = kern_param.N, depth = kern_param.K;
        auto transpose_a = kern_param.trA, transpose_b = kern_param.trB;
        // Data types of the three operands.
        auto dtype_a = kern_param.A_type, dtype_b = kern_param.B_type,
             dtype_c = kern_param.C_type;

        Strategy strategy(rows, cols, depth, dtype_a, dtype_b, dtype_c);
        megdnn::matmul::GemmInterleaved<Strategy> gemm(
                rows, cols, depth, transpose_a, transpose_b, strategy);

        const auto src_a = kern_param.A<dt_int8>();
        const auto src_b = kern_param.B<dt_int8>();
        auto dst_c = kern_param.C<dt_int16>();
        gemm.execute(src_a, kern_param.LDA, src_b, kern_param.LDB, dst_c,
                     kern_param.LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
/*!
 * \brief Whether this algorithm can handle the given problem.
 *
 * Requires: A and B are both QuantizedS4, C is QuantizedS16, the matmul
 * format and compute mode are DEFAULT, and both K and N are even.
 */
bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    const auto a_enum = kern_size_param.A_type.enumv();

    // Operand data types.
    if (a_enum != kern_size_param.B_type.enumv()) {
        return false;
    }
    if (a_enum != DTypeEnum::QuantizedS4) {
        return false;
    }
    if (kern_size_param.C_type.enumv() != DTypeEnum::QuantizedS16) {
        return false;
    }

    // Layout and compute mode.
    if (kern_size_param.format != param::MatrixMul::Format::DEFAULT) {
        return false;
    }
    if (kern_size_param.compute_mode != Param::ComputeMode::DEFAULT) {
        return false;
    }

    // Both K and N must be even.
    const bool k_is_even = (kern_size_param.K & 1) == 0;
    const bool n_is_even = (kern_size_param.N & 1) == 0;
    return k_is_even && n_is_even;
}
/*!
 * \brief Preference hint for algorithm selection.
 *
 * Always reports "preferred"; the size parameter is intentionally ignored.
 */
bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred(
        const KernSizeParam& kern_size_param) const {
    // The decision does not depend on the problem size.
    MEGDNN_MARK_USED_VAR(kern_size_param);
    const bool always_preferred = true;
    return always_preferred;
}
size_t
MatrixMulImpl
::
AlgoInt4x4x16K8x8x8
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_aarch64_matmul_kern
,
midout_iv
(
"AlgoInt4x4x16K8x8x8::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
aarch64
::
matmul
::
gemm_s4x4x16_s4_8x8x8
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
gemm_s4x4x16_s4_8x8x8
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
}
/*!
 * \brief Kernel entry point for this algorithm.
 *
 * The size parameter is ignored; the same kernel is returned for every
 * problem.
 */
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern(
        const KernSizeParam&) const {
    kern_t kernel = int4x4x16_k8x8x16_kern;
    return kernel;
}
// Instantiates the im2col registration macro for AlgoInt4x4x16K8x8x8 with the
// gemm_s4x4x16_s4_8x8x8 strategy, int8_t input storage, int16_t output,
// the INT4X4X16 algo data type and the DEFAULT format. This is the
// out-of-class counterpart of MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() in the class
// declaration (the macro expansion itself is not visible in this file).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt4x4x16K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s4x4x16_s4_8x8x8,
                                     int8_t, int16_t,
                                     AlgoDataType::INT4X4X16, DEFAULT);
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
86cf7490
...
...
@@ -192,6 +192,19 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
AARCH64_INT8X8X16_K4X4X16
)
};
/*!
 * \brief Matrix-mul algorithm descriptor named "AARCH64_INT4X4X16_K8X8X8".
 *
 * Declares the AlgoBase overrides for this algorithm; usable(), preferred(),
 * get_workspace() and get_kern() are defined out of line.
 */
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
    //! This algorithm reports itself as reproducible.
    bool is_reproducible() const override { return true; }
    //! Name under which the algorithm is exposed.
    const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
    //! Whether the algorithm supports the given problem.
    bool usable(const KernSizeParam&) const override;
    //! Preference hint for the given problem.
    bool preferred(const KernSizeParam&) const override;
    //! Required workspace size for the given problem.
    size_t get_workspace(const KernSizeParam&) const override;
    //! Kernel entry point for the given problem.
    kern_t get_kern(const KernSizeParam&) const override;
    //! Uses the DEFAULT pack mode.
    PackMode packmode() const override { return PackMode::DEFAULT; }
    // Declarations required for im2col use; the matching definitions are
    // emitted by MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL in the .cpp file.
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
    MEGDNN_DECL_ALGO_TYPE(AARCH64_INT4X4X16_K8X8X8)
};
class
MatrixMulImpl
::
AlgoInt8x8x16MK4_16x12x4
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/aarch64/matrix_mul/asm/common.h
浏览文件 @
86cf7490
...
...
@@ -925,6 +925,42 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1,
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"memory"
);
}
/*!
 * \brief Unpack 4 bytes (8 packed signed 4-bit values) from each of 8 input
 *        pointers into 64 sign-extended 8-bit values.
 *
 * For every input byte the low nibble is emitted first, followed by the high
 * nibble, each sign-extended to 8 bits. The 8 unpacked values of inptr0 come
 * first in the output, then those of inptr1, and so on up to inptr7.
 *
 * \param inptr0..inptr7 source pointers; each is advanced by 4 bytes.
 * \param outptr destination for 64 bytes. It is passed by value, so the
 *        caller's pointer is not advanced even though the asm post-increments
 *        its local copy.
 */
template <typename T>
static inline void interleave_8x4_1_b_with_shift(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
        T* outptr) {
    static_assert(sizeof(T) == 1, "only support size == 1");
    asm volatile(
            // Gather 4 bytes per row: rows 0-3 into v0, rows 4-7 into v1.
            "ld1 {v0.s}[0], [%[inptr0]], #4\n"
            "ld1 {v0.s}[1], [%[inptr1]], #4\n"
            "ld1 {v0.s}[2], [%[inptr2]], #4\n"
            "ld1 {v0.s}[3], [%[inptr3]], #4\n"
            "ld1 {v1.s}[0], [%[inptr4]], #4\n"
            "ld1 {v1.s}[1], [%[inptr5]], #4\n"
            "ld1 {v1.s}[2], [%[inptr6]], #4\n"
            "ld1 {v1.s}[3], [%[inptr7]], #4\n"
            // Move the low nibble into the high half of each byte so that the
            // arithmetic right shift below sign-extends it.
            "shl v2.16b, v0.16b, #4\n"
            "shl v5.16b, v1.16b, #4\n"
            "sshr v3.16b, v0.16b, #4\n"  // high nibbles of rows 0-3
            "sshr v4.16b, v2.16b, #4\n"  // low nibbles of rows 0-3
            "sshr v6.16b, v1.16b, #4\n"  // high nibbles of rows 4-7
            "sshr v7.16b, v5.16b, #4\n"  // low nibbles of rows 4-7
            // Interleave low/high so each source byte becomes (low, high).
            "zip1 v8.16b, v4.16b, v3.16b\n"
            "zip2 v9.16b, v4.16b, v3.16b\n"
            "zip1 v10.16b, v7.16b, v6.16b\n"
            "zip2 v11.16b, v7.16b, v6.16b\n"
            "st1 {v8.16b-v11.16b},[%[outptr]],#64"
            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
              [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
              [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
              [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7),
              [outptr] "+r"(outptr)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "memory");
}
template
<
typename
T
>
static
inline
void
interleave_8x8_1_h
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1059,6 +1095,7 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1,
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"cc"
,
"memory"
);
}
template
<
typename
T
>
static
inline
void
interleave_4x16_1_s
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
@@ -1772,6 +1809,54 @@ static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1,
:
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"memory"
);
}
/*!
 * \brief Transpose an 8-row x 4-byte tile of packed signed 4-bit values and
 *        expand every nibble to a sign-extended 8-bit value.
 *
 * Each input byte holds two 4-bit values: the low nibble is the value at the
 * even position, the high nibble the value at the odd position. With r the
 * row (0..7), j the byte within the row (0..3) and h selecting the nibble
 * (0 = low, 1 = high), the output layout is
 *     out[(2 * j + h) * 8 + r] = sign_extend(nibble h of row r, byte j)
 * i.e. 8 groups of 8 values, one group per unpacked column.
 *
 * \param inptr0..inptr7 row pointers; each is advanced by 4 bytes.
 * \param outptr destination; 64 bytes are written and the pointer is advanced
 *        by 64.
 */
template <typename T>
static inline void transpose_4x8_1_b_with_shift(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7,
        T*& outptr) {
    // Byte permutation that transposes a 4x4 byte matrix held in one vector.
    // It is only read by `tbl`, so it is a constant bound as an input operand;
    // binding a mutable static as a read-write operand would make every call
    // write it back to shared storage.
    const int8x16_t shuffle_idx = {0, 4, 8,  12, 1, 5, 9,  13,
                                   2, 6, 10, 14, 3, 7, 11, 15};
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "transpose_4x8_1_b_with_shift only support uint8_t and int8_t");
    asm volatile(
            "ld1 {v0.s}[0], [%[inptr0]], #4\n"  // A1A2A3A4
            "ld1 {v0.s}[1], [%[inptr1]], #4\n"  // B1B2B3B4
            "ld1 {v0.s}[2], [%[inptr2]], #4\n"  // C1C2C3C4
            "ld1 {v0.s}[3], [%[inptr3]], #4\n"  // D1D2D3D4
            "ld1 {v1.s}[0], [%[inptr4]], #4\n"  // E1E2E3E4
            "ld1 {v1.s}[1], [%[inptr5]], #4\n"  // F1F2F3F4
            "ld1 {v1.s}[2], [%[inptr6]], #4\n"  // G1G2G3G4
            "ld1 {v1.s}[3], [%[inptr7]], #4\n"  // H1H2H3H4

            "tbl v2.16b, {v0.16b}, %[shuffle_idx].16b\n"  // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4
            "tbl v3.16b, {v1.16b}, %[shuffle_idx].16b\n"  // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4

            "zip1 v4.4s, v2.4s, v3.4s\n"  // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2
            "zip2 v5.4s, v2.4s, v3.4s\n"  // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4

            // Shift left then arithmetic-shift right to sign-extend the low
            // nibble; a plain arithmetic shift right yields the high nibble.
            "shl v6.16b, v4.16b, #4\n"
            "sshr v7.16b, v4.16b, #4\n"  // high nibbles of bytes 1 and 2
            "sshr v8.16b, v6.16b, #4\n"  // low nibbles of bytes 1 and 2
            "shl v9.16b, v5.16b, #4\n"
            "sshr v10.16b, v5.16b, #4\n"  // high nibbles of bytes 3 and 4
            "sshr v11.16b, v9.16b, #4\n"  // low nibbles of bytes 3 and 4

            // Order per byte column: 8 low nibbles, then 8 high nibbles.
            "zip1 v0.2d,v8.2d,v7.2d\n"
            "zip2 v1.2d,v8.2d,v7.2d\n"
            "zip1 v2.2d,v11.2d,v10.2d\n"
            "zip2 v3.2d,v11.2d,v10.2d\n"
            "st1 {v0.2d-v3.2d},[%[outptr]],#64\n"
            : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
              [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
              [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
              [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7),
              [outptr] "+r"(outptr)
            : [shuffle_idx] "w"(shuffle_idx)
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "memory");
}
template
<
typename
T
>
static
inline
void
transpose_8x8_1_b
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
...
...
dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
0 → 100644
浏览文件 @
86cf7490
/**
 * \file dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <inttypes.h>
#include <cstring>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul_s4_4x4x16
{
/**
* Overview of register layout:
*
* +---------+---------+---------+---------+
* |v20[0-15]|v21[0-15]|v22[0-15]|v23[0-15]|
* Rhs +---------+---------+---------+---------+
* Lhs | | |
*
* +--------+ - - - - +---------+---------+---------+---------+
* |v0[0-15]| | v4[0-8] | v8[0-8]| v12[0-8]| v16[0-8]|
* |v1[0-15]| | v5[0-8] | v9[0-8]| v13[0-8]| v17[0-8]|
* |v2[0-15]| | v6[0-8] | v10[0-8]| v14[0-8]| v18[0-8]|
* |v3[0-15]| | v7[0-8] | v11[0-8]| v15[0-8]| v19[0-8]|
* +--------+ - - - - +---------+---------+---------+---------+
*
* Accumulator
*/
static
void
s4_kern_8x8_remain
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
K
/=
8
;
LDC
=
LDC
*
sizeof
(
int16_t
);
const
int8_t
*
a_ptr
=
packA
;
const
int8_t
*
b_ptr
=
packB
;
// clang-format off
#define LOAD_LINE(reg_index, n) \
"cmp x8, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #4\n" \
"blt 100" n "f\n" \
"ld1 {v" reg_index ".8h}, [x" n "], #16\n" \
"b 101" n "f\n" \
"100" n ":\n" \
"cmp %w[n_remain], #0\n" \
"blt 101" n "f\n" \
"ld1 {v" reg_index ".h}[0], [x" n "], #2\n" \
"cmp %w[n_remain], #1\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[1], [x" n "], #2\n" \
"cmp %w[n_remain], #2\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[2], [x" n "], #2\n" \
"cmp %w[n_remain], #3\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[3], [x" n "], #2\n" \
"cmp %w[n_remain], #4\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[4], [x" n "], #2\n" \
"cmp %w[n_remain], #5\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[5], [x" n "], #2\n" \
"cmp %w[n_remain], #6\n" \
"beq 101" n "f\n" \
"ld1 {v" reg_index ".h}[6], [x" n "], #2\n" \
"101" n ":\n" \
"sub x8, x8, #1\n"
#define LOAD_C \
"mov x8, %x[m_remain]\n" \
LOAD_LINE("24", "0") \
LOAD_LINE("25", "1") \
LOAD_LINE("26", "2") \
LOAD_LINE("27", "3") \
LOAD_LINE("28", "4") \
LOAD_LINE("29", "5") \
LOAD_LINE("30", "6") \
LOAD_LINE("31", "7") \
"105:\n"
#define STORE_LINE(reg_index, n) \
"cmp x8, #0 \n" \
"beq 105f\n" \
"cmp %w[n_remain], #8\n" \
"blt 102" n "f\n" \
"st1 {v" reg_index ".8h}, [x" n "], #16\n" \
"b 103" n "f\n" \
"102" n ":\n" \
"cmp %w[n_remain], #0\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[0], [x" n "], #2\n" \
"cmp %w[n_remain], #1\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[1], [x" n "], #2\n" \
"cmp %w[n_remain], #2\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[2], [x" n "], #2\n" \
"cmp %w[n_remain], #3\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[3], [x" n "], #2\n" \
"cmp %w[n_remain], #4\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[4], [x" n "], #2\n" \
"cmp %w[n_remain], #5\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[5], [x" n "], #2\n" \
"cmp %w[n_remain], #6\n" \
"beq 103" n "f\n" \
"st1 {v" reg_index ".h}[6], [x" n "], #2\n" \
"103" n ":\n" \
"sub x8, x8, #1\n"
#define STORE_C \
"mov x8, %x[m_remain]\n" \
STORE_LINE("24", "0") \
STORE_LINE("25", "1") \
STORE_LINE("26", "2") \
STORE_LINE("27", "3") \
STORE_LINE("28", "4") \
STORE_LINE("29", "5") \
STORE_LINE("30", "6") \
STORE_LINE("31", "7") \
"105:\n"
// clang-format on
register
int16_t
*
outptr
asm
(
"x0"
)
=
output
;
asm
volatile
(
"add x1, x0, %x[LDC]
\n
"
"add x2, x1, %x[LDC]
\n
"
"add x3, x2, %x[LDC]
\n
"
"add x4, x3, %x[LDC]
\n
"
"add x5, x4, %x[LDC]
\n
"
"add x6, x5, %x[LDC]
\n
"
"add x7, x6, %x[LDC]
\n
"
"cmp %w[is_first_k], #1
\n
"
"beq 2f
\n
"
LOAD_C
"b 1f
\n
"
"2:
\n
"
// Clear the C regs.
"eor v24.16b, v24.16b, v24.16b
\n
"
"eor v25.16b, v25.16b, v25.16b
\n
"
"eor v26.16b, v26.16b, v26.16b
\n
"
"eor v27.16b, v27.16b, v27.16b
\n
"
"eor v28.16b, v28.16b, v28.16b
\n
"
"eor v29.16b, v29.16b, v29.16b
\n
"
"eor v30.16b, v30.16b, v30.16b
\n
"
"eor v31.16b, v31.16b, v31.16b
\n
"
// General loop.
"1:
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"dup v0.8b,v20.b[0]
\n
"
"dup v1.8b,v20.b[1]
\n
"
"dup v2.8b,v20.b[2]
\n
"
"dup v3.8b,v20.b[3]
\n
"
"ld1 {v22.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v23.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v4.8b,v20.b[4]
\n
"
"dup v5.8b,v20.b[5]
\n
"
"dup v6.8b,v20.b[6]
\n
"
"dup v7.8b,v20.b[7]
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v20.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v20.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v20.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v20.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v20.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v20.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v20.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v20.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v18.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v21.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v21.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v21.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v21.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v21.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v21.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v21.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v21.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v19.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v21.b[8]
\n
"
"smlal v24.8h, v0.8b, v18.8b
\n
"
"dup v9.8b,v21.b[9]
\n
"
"smlal v25.8h, v1.8b, v18.8b
\n
"
"dup v10.8b,v21.b[10]
\n
"
"smlal v26.8h, v2.8b, v18.8b
\n
"
"dup v11.8b,v21.b[11]
\n
"
"smlal v27.8h, v3.8b, v18.8b
\n
"
"dup v12.8b,v21.b[12]
\n
"
"smlal v28.8h, v4.8b, v18.8b
\n
"
"dup v13.8b,v21.b[13]
\n
"
"smlal v29.8h, v5.8b, v18.8b
\n
"
"dup v14.8b,v21.b[14]
\n
"
"smlal v30.8h, v6.8b, v18.8b
\n
"
"dup v15.8b,v21.b[15]
\n
"
"smlal v31.8h, v7.8b, v18.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v22.b[0]
\n
"
"smlal v24.8h, v8.8b, v19.8b
\n
"
"dup v1.8b,v22.b[1]
\n
"
"smlal v25.8h, v9.8b, v19.8b
\n
"
"dup v2.8b,v22.b[2]
\n
"
"smlal v26.8h, v10.8b, v19.8b
\n
"
"dup v3.8b,v22.b[3]
\n
"
"smlal v27.8h, v11.8b, v19.8b
\n
"
"dup v4.8b,v22.b[4]
\n
"
"smlal v28.8h, v12.8b, v19.8b
\n
"
"dup v5.8b,v22.b[5]
\n
"
"smlal v29.8h, v13.8b, v19.8b
\n
"
"dup v6.8b,v22.b[6]
\n
"
"smlal v30.8h, v14.8b, v19.8b
\n
"
"dup v7.8b,v22.b[7]
\n
"
"smlal v31.8h, v15.8b, v19.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v22.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v22.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v22.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v22.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v22.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v22.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v22.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v22.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v18.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v23.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v23.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v23.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v23.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v23.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v23.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v23.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v23.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v19.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v23.b[8]
\n
"
"smlal v24.8h, v0.8b, v18.8b
\n
"
"dup v9.8b,v23.b[9]
\n
"
"smlal v25.8h, v1.8b, v18.8b
\n
"
"dup v10.8b,v23.b[10]
\n
"
"smlal v26.8h, v2.8b, v18.8b
\n
"
"dup v11.8b,v23.b[11]
\n
"
"smlal v27.8h, v3.8b, v18.8b
\n
"
"dup v12.8b,v23.b[12]
\n
"
"smlal v28.8h, v4.8b, v18.8b
\n
"
"dup v13.8b,v23.b[13]
\n
"
"smlal v29.8h, v5.8b, v18.8b
\n
"
"dup v14.8b,v23.b[14]
\n
"
"smlal v30.8h, v6.8b, v18.8b
\n
"
"dup v15.8b,v23.b[15]
\n
"
"smlal v31.8h, v7.8b, v18.8b
\n
"
"smlal v24.8h, v8.8b, v19.8b
\n
"
"smlal v25.8h, v9.8b, v19.8b
\n
"
"smlal v26.8h, v10.8b, v19.8b
\n
"
"smlal v27.8h, v11.8b, v19.8b
\n
"
"smlal v28.8h, v12.8b, v19.8b
\n
"
"smlal v29.8h, v13.8b, v19.8b
\n
"
"smlal v30.8h, v14.8b, v19.8b
\n
"
"smlal v31.8h, v15.8b, v19.8b
\n
"
"subs %w[K], %w[K], #1
\n
"
"cbnz %w[K], 1b
\n
"
"3:
\n
"
// Store back into memory
STORE_C
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
K
]
"+r"
(
K
),
[
LDC
]
"+r"
(
LDC
),
[
outptr
]
"+r"
(
outptr
),
[
m_remain
]
"+r"
(
m_remain
),
[
n_remain
]
"+r"
(
n_remain
)
//,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
:
:
"cc"
,
"memory"
,
"x1"
,
"x2"
,
"x3"
,
"x4"
,
"x5"
,
"x6"
,
"x7"
,
"x8"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
static
void
s4_kern_8x8
(
const
int8_t
*
packA
,
const
int8_t
*
packB
,
int
K
,
int16_t
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
K
/=
8
;
LDC
=
LDC
*
sizeof
(
int16_t
);
const
int8_t
*
a_ptr
=
packA
;
const
int8_t
*
b_ptr
=
packB
;
// clang-format off
#define LOAD_C_8 \
"ld1 {v24.8h}, [x0], #16\n" \
"ld1 {v25.8h}, [x1], #16\n" \
"ld1 {v26.8h}, [x2], #16\n" \
"ld1 {v27.8h}, [x3], #16\n" \
"ld1 {v28.8h}, [x4], #16\n" \
"ld1 {v29.8h}, [x5], #16\n" \
"ld1 {v30.8h}, [x6], #16\n" \
"ld1 {v31.8h}, [x7], #16\n" \
#define STORE_C_8 \
"st1 {v24.8h}, [x0], #16\n" \
"st1 {v25.8h}, [x1], #16\n" \
"st1 {v26.8h}, [x2], #16\n" \
"st1 {v27.8h}, [x3], #16\n" \
"st1 {v28.8h}, [x4], #16\n" \
"st1 {v29.8h}, [x5], #16\n" \
"st1 {v30.8h}, [x6], #16\n" \
"st1 {v31.8h}, [x7], #16\n" \
// clang-format on
register
int16_t
*
outptr
asm
(
"x0"
)
=
output
;
asm
volatile
(
"add x1, x0, %x[LDC]
\n
"
"add x2, x1, %x[LDC]
\n
"
"add x3, x2, %x[LDC]
\n
"
"add x4, x3, %x[LDC]
\n
"
"add x5, x4, %x[LDC]
\n
"
"add x6, x5, %x[LDC]
\n
"
"add x7, x6, %x[LDC]
\n
"
"cmp %w[is_first_k], #1
\n
"
"beq 2f
\n
"
LOAD_C_8
"b 1f
\n
"
"2:
\n
"
// Clear the C regs.
"eor v24.16b, v24.16b, v24.16b
\n
"
"eor v25.16b, v25.16b, v25.16b
\n
"
"eor v26.16b, v26.16b, v26.16b
\n
"
"eor v27.16b, v27.16b, v27.16b
\n
"
"eor v28.16b, v28.16b, v28.16b
\n
"
"eor v29.16b, v29.16b, v29.16b
\n
"
"eor v30.16b, v30.16b, v30.16b
\n
"
"eor v31.16b, v31.16b, v31.16b
\n
"
// General loop.
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"PRFM PLDL1KEEP, [%[a_ptr], #512]
\n
"
"PRFM PLDL1KEEP, [%[b_ptr], #512]
\n
"
"1:
\n
"
// "ld1 {v20.16b}, [%[a_ptr]],#16\n"
// "ld1 {v21.16b}, [%[a_ptr]],#16\n"
"dup v0.8b,v20.b[0]
\n
"
"ld1 {v22.16b}, [%[a_ptr]],#16
\n
"
"dup v1.8b,v20.b[1]
\n
"
"ld1 {v23.16b}, [%[a_ptr]],#16
\n
"
"dup v2.8b,v20.b[2]
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v3.8b,v20.b[3]
\n
"
"dup v4.8b,v20.b[4]
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v5.8b,v20.b[5]
\n
"
"dup v6.8b,v20.b[6]
\n
"
"dup v7.8b,v20.b[7]
\n
"
"dup v8.8b,v20.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v20.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v20.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v20.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v20.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v20.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v20.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v20.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v21.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v21.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v21.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v21.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v21.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v21.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v21.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v21.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v21.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v21.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v21.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v21.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v21.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v21.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v21.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v21.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v22.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v22.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v22.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v22.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v22.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v22.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v22.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v22.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v22.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v22.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v22.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v22.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v22.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v22.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v22.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v22.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v16.8b}, [%[b_ptr]], 8
\n
"
"dup v0.8b,v23.b[0]
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"dup v1.8b,v23.b[1]
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"dup v2.8b,v23.b[2]
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"dup v3.8b,v23.b[3]
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"dup v4.8b,v23.b[4]
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"dup v5.8b,v23.b[5]
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"dup v6.8b,v23.b[6]
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"dup v7.8b,v23.b[7]
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
"ld1 {v17.8b}, [%[b_ptr]], 8
\n
"
"dup v8.8b,v23.b[8]
\n
"
"smlal v24.8h, v0.8b, v16.8b
\n
"
"dup v9.8b,v23.b[9]
\n
"
"smlal v25.8h, v1.8b, v16.8b
\n
"
"dup v10.8b,v23.b[10]
\n
"
"smlal v26.8h, v2.8b, v16.8b
\n
"
"dup v11.8b,v23.b[11]
\n
"
"smlal v27.8h, v3.8b, v16.8b
\n
"
"dup v12.8b,v23.b[12]
\n
"
"smlal v28.8h, v4.8b, v16.8b
\n
"
"dup v13.8b,v23.b[13]
\n
"
"smlal v29.8h, v5.8b, v16.8b
\n
"
"dup v14.8b,v23.b[14]
\n
"
"smlal v30.8h, v6.8b, v16.8b
\n
"
"dup v15.8b,v23.b[15]
\n
"
"smlal v31.8h, v7.8b, v16.8b
\n
"
"ld1 {v20.16b}, [%[a_ptr]],#16
\n
"
"smlal v24.8h, v8.8b, v17.8b
\n
"
"smlal v25.8h, v9.8b, v17.8b
\n
"
"smlal v26.8h, v10.8b, v17.8b
\n
"
"smlal v27.8h, v11.8b, v17.8b
\n
"
"ld1 {v21.16b}, [%[a_ptr]],#16
\n
"
"smlal v28.8h, v12.8b, v17.8b
\n
"
"smlal v29.8h, v13.8b, v17.8b
\n
"
"smlal v30.8h, v14.8b, v17.8b
\n
"
"smlal v31.8h, v15.8b, v17.8b
\n
"
//"ld1 {v20.16b}, [%[a_ptr]],#16\n"
//"ld1 {v21.16b}, [%[a_ptr]],#16\n"
"subs %w[K], %w[K], #1
\n
"
"cbnz %w[K], 1b
\n
"
"3:
\n
"
// Store back into memory
STORE_C_8
:
[
a_ptr
]
"+r"
(
a_ptr
),
[
b_ptr
]
"+r"
(
b_ptr
),
[
is_first_k
]
"+r"
(
is_first_k
),
[
K
]
"+r"
(
K
),
[
LDC
]
"+r"
(
LDC
),
[
outptr
]
"+r"
(
outptr
),
[
m_remain
]
"+r"
(
m_remain
),
[
n_remain
]
"+r"
(
n_remain
)
//,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
:
:
"cc"
,
"memory"
,
"x1"
,
"x2"
,
"x3"
,
"x4"
,
"x5"
,
"x6"
,
"x7"
,
"x8"
,
"v0"
,
"v1"
,
"v2"
,
"v3"
,
"v4"
,
"v5"
,
"v6"
,
"v7"
,
"v8"
,
"v9"
,
"v10"
,
"v11"
,
"v12"
,
"v13"
,
"v14"
,
"v15"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
,
"v27"
,
"v28"
,
"v29"
,
"v30"
,
"v31"
);
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
/*!
 * \brief Pack routine ("packa"): repacks rows [y0, ymax) of a packed-int4
 *        matrix, 8 rows at a time, through transpose_4x8_1_b_with_shift.
 *
 * For every group of 8 rows the source is consumed 4 bytes per row per call;
 * a trailing chunk of 1..3 bytes is first copied into zero-initialised
 * scratch rows so that a full 4 bytes can still be read. Rows beyond
 * \p ymax in the last group are replaced by a zero buffer.
 *
 * \param outptr destination; advanced by transpose_4x8_1_b_with_shift
 * \param inptr source buffer
 * \param ldin source row stride; halved before being applied to \p inptr
 * \param y0, ymax row range to pack
 * \param k0, kmax depth range; (kmax - k0) / 2 bytes are read per row
 *
 * NOTE(review): the row start is computed as `inptr + y * ldin + k0`, i.e.
 * k0 is added un-halved although ldin and the per-row byte count are halved
 * -- verify the intended unit of k0 for callers that pass k0 != 0.
 */
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr,
                                              const dt_int8* inptr, int ldin,
                                              int y0, int ymax, int k0,
                                              int kmax) {
    constexpr int kTileRows = 8;

    // Zero source for padded rows, and per-row scratch for the depth tail.
    int8_t zerobuff[8];
    int8_t tmpbuff[kTileRows][8];
    std::memset(zerobuff, 0, sizeof(zerobuff));
    std::memset(tmpbuff, 0, sizeof(tmpbuff));

    ldin /= 2;

    for (int y = y0; y < ymax; y += kTileRows) {
        const int rows_left = ymax - y;
        const int valid_rows = rows_left < kTileRows ? rows_left : kTileRows;
        const bool full_tile = (valid_rows == kTileRows);

        const int8_t* row[kTileRows];
        row[0] = inptr + y * ldin + k0;
        for (int i = 1; i < kTileRows; ++i) {
            row[i] = row[i - 1] + ldin;
        }

        // Prefetch is only issued for complete 8-row groups.
        if (full_tile) {
            for (int i = 0; i < kTileRows; ++i) {
                prefetch_2x(row[i]);
            }
        }

        int K = (kmax - k0) / 2;
        //! read 4 bytes in each row per call
        for (; K > 3; K -= 4) {
            // Rows past ymax always read from the start of the zero buffer.
            for (int i = valid_rows; i < kTileRows; ++i) {
                row[i] = zerobuff;
            }
            transpose_4x8_1_b_with_shift(row[0], row[1], row[2], row[3],
                                         row[4], row[5], row[6], row[7],
                                         outptr);
        }

        if (K > 0) {
            for (int i = valid_rows; i < kTileRows; ++i) {
                row[i] = zerobuff;
            }
            // Stage the 1..3 leftover bytes so a 4-byte read stays in bounds.
            for (int i = 0; i < kTileRows; ++i) {
                std::memcpy(tmpbuff[i], row[i], K);
                row[i] = tmpbuff[i];
            }
            transpose_4x8_1_b_with_shift(row[0], row[1], row[2], row[3],
                                         row[4], row[5], row[6], row[7],
                                         outptr);
        }
    }
}
// packb
/*!
 * \brief Pack a [k0, kmax) x [x0, xmax) window of \p in into \p out using
 *        interleave_8x4_1_b_with_shift.
 *
 * K is walked in groups of 8 rows and the columns in groups of 4 bytes.
 * Every call of the interleave helper consumes one 8-row x 4-byte tile.
 * Consecutive column groups are round_up(kmax - k0, 8) * 8 bytes apart in
 * \p out, and each K group starts 64 bytes after the previous one.
 *
 * Rows past kmax in the last K group are read from a zero buffer; a final
 * column group narrower than 4 bytes is first copied into zero-initialised
 * scratch rows so the helper always sees full-width input.
 *
 * \param out  destination buffer
 * \param in   source matrix
 * \param ldin leading dimension of \p in; halved before use (presumably
 *             converting an int4 element count into bytes -- TODO confirm)
 * \param x0   first column; used unscaled as an offset into each row
 * \param xmax column limit; halved before use
 * \param k0   first K row
 * \param kmax K row limit (exclusive)
 *
 * NOTE(review): xmax and ldin are halved but x0 is not, so the column range
 * is only obviously consistent for x0 == 0 -- confirm against the callers.
 */
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
                                               int ldin, int x0, int xmax,
                                               int k0, int kmax) {
    constexpr int kRows = 8;

    // Zero source for K rows that lie beyond kmax.
    int8_t zero_row[8] = {};
    // Zero-padded staging rows for the last, partial column group.
    int8_t edge_rows[kRows][8] = {};

    // Distance in bytes between two consecutive column groups in `out`.
    const int packed_stride = round_up(kmax - k0, 8) * 8;

    int8_t* k_block_out = out;
    ldin /= 2;
    xmax /= 2;

    // Current read position of each of the 8 rows of the active K group.
    // The interleave helper is expected to advance the pointers it is given:
    // the column loops below never increment them themselves.
    const int8_t* rows[kRows];

    int k = k0;
    while (k + 7 < kmax) {
        rows[0] = in + k * ldin + x0;
        for (int r = 1; r < kRows; ++r) {
            rows[r] = rows[r - 1] + ldin;
        }
        for (int r = 0; r < kRows; ++r) {
            prefetch_2x(rows[r]);
        }

        int x = x0;
        int8_t* col_out = k_block_out;
        for (; x + 3 < xmax; x += 4) {
            int8_t* dst = col_out;
            interleave_8x4_1_b_with_shift(rows[0], rows[1], rows[2], rows[3],
                                          rows[4], rows[5], rows[6], rows[7],
                                          dst);
            col_out += packed_stride;
        }
        if (x < xmax) {
            // Fewer than 4 bytes left per row: stage them in padded scratch.
            const int tail_bytes = xmax - x;
            for (int r = 0; r < kRows; ++r) {
                std::memcpy(edge_rows[r], rows[r], tail_bytes);
                rows[r] = edge_rows[r];
            }
            int8_t* dst = col_out;
            interleave_8x4_1_b_with_shift(rows[0], rows[1], rows[2], rows[3],
                                          rows[4], rows[5], rows[6], rows[7],
                                          dst);
        }

        k_block_out += 64;
        k += 8;
    }

    if (k < kmax) {
        rows[0] = in + k * ldin + x0;
        for (int r = 1; r < kRows; ++r) {
            rows[r] = rows[r - 1] + ldin;
        }

        // Index of the last row that still holds real data (0..6 here).
        const int last_valid_row = kmax - k - 1;

        // Redirect every row after last_valid_row to the zero buffer. This is
        // redone before each interleave call because the helper may have
        // moved the pointers.
        auto zero_pad_rows = [&]() {
            if (last_valid_row < 0 || last_valid_row > 6) {
                megdnn_assert(0);
                return;
            }
            for (int r = last_valid_row + 1; r < kRows; ++r) {
                rows[r] = zero_row;
            }
        };

        int x = x0;
        int8_t* col_out = k_block_out;
        for (; x + 3 < xmax; x += 4) {
            zero_pad_rows();
            int8_t* dst = col_out;
            interleave_8x4_1_b_with_shift(rows[0], rows[1], rows[2], rows[3],
                                          rows[4], rows[5], rows[6], rows[7],
                                          dst);
            col_out += packed_stride;
        }
        if (x < xmax) {
            zero_pad_rows();
            const int tail_bytes = xmax - x;
            for (int r = 0; r < kRows; ++r) {
                std::memcpy(edge_rows[r], rows[r], tail_bytes);
                rows[r] = edge_rows[r];
            }
            int8_t* dst = col_out;
            interleave_8x4_1_b_with_shift(rows[0], rows[1], rows[2], rows[3],
                                          rows[4], rows[5], rows[6], rows[7],
                                          dst);
        }
    }
}
}
// namespace matmul_s4_4x4x16
}
// namespace aarch64
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
0 → 100644
浏览文件 @
86cf7490
/**
* \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
using
namespace
megdnn
;
using
namespace
aarch64
;
using
namespace
aarch64
::
matmul
;
// ===========================gemm_s4x4x16_s4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gemm_s4x4x16_s4_8x8x8
);
/*!
 * \brief Pack the A operand for the int4x4x16 8x8x8 kernel.
 *
 * Forwards to the transpose packer when \p transpose is false and to the
 * interleave packer when it is true; all other arguments are passed through
 * unchanged.
 */
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin,
                                   int y0, int ymax, int k0, int kmax,
                                   bool transpose) const {
    if (!transpose) {
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0,
                                                            ymax, k0, kmax);
        return;
    }
    matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0,
                                                         ymax, k0, kmax);
}
/*!
 * \brief Pack the B operand for the int4x4x16 8x8x8 kernel.
 *
 * Forwards to the interleave packer when \p transpose is false and to the
 * transpose packer when it is true, i.e. the opposite selection to pack_A.
 * All other arguments are passed through unchanged.
 */
void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                   int x0, int xmax, int k0, int kmax,
                                   bool transpose) const {
    if (!transpose) {
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0,
                                                             xmax, k0, kmax);
        return;
    }
    matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax,
                                                        k0, kmax);
}
/*!
 * \brief Run the packed int4 GEMM over an M x N output in 8x8 tiles.
 *
 * \param packA      packed A panel; advanced by K8 bytes per 8-row block
 * \param packB      packed B panel; advanced by K8 bytes per 8-column block
 * \param M, N       output extent
 * \param K          reduction length; rounded up to a multiple of 8 here
 * \param C          output matrix
 * \param LDC        row stride of C in elements
 * \param is_first_k forwarded unchanged to the tile kernels
 *
 * Tiles that are a full 8x8 go through s4_kern_8x8; any tile clipped by the
 * M or N edge goes through s4_kern_8x8_remain with the clipped extents.
 */
void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
                                 size_t M, size_t N, size_t K, dt_int16* C,
                                 size_t LDC, bool is_first_k, const dt_int16*,
                                 dt_int16*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          (A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
                           C_dtype.enumv() == DTypeEnum::QuantizedS16),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);

    constexpr size_t A_INTERLEAVE = 8;
    constexpr size_t B_INTERLEAVE = 8;

    //! the packers pad K to a multiple of 8, so do the same here
    K = round_up<size_t>(K, 8);
    // Byte step between consecutive 8-wide panels in packA / packB.
    // NOTE(review): size_t -> int narrowing; fine for realistic K, but a
    // size_t would avoid the implicit conversion -- confirm before changing.
    const int K8 = K * 8;

    for (size_t m = 0; m < M; m += A_INTERLEAVE) {
        const size_t rows_here = std::min<size_t>(M - m, A_INTERLEAVE);
        int16_t* output = C + (m * LDC);
        const dt_int8* cur_packB = packB;

        for (size_t n = 0; n < N; n += B_INTERLEAVE) {
            const size_t cols_here = std::min<size_t>(N - n, B_INTERLEAVE);
            const bool full_tile =
                    rows_here == A_INTERLEAVE && cols_here == B_INTERLEAVE;
            if (full_tile) {
                matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC,
                                              is_first_k, A_INTERLEAVE,
                                              B_INTERLEAVE);
            } else {
                matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K,
                                                     output, LDC, is_first_k,
                                                     rows_here, cols_here);
            }
            output += B_INTERLEAVE;
            cur_packB += K8;
        }
        packA += K8;
    }
}
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
0 → 100644
浏览文件 @
86cf7490
/**
* \file dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
// Declares the GEMM strategy `gemm_s4x4x16_s4_8x8x8` inside
// megdnn::aarch64::matmul. Its pack_A / pack_B / kern members are defined in
// int4x4x16/strategy.cpp.
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul
{
// NOTE(review): the macro arguments presumably are the input type, two
// int16 output/compute types, three block sizes, two boolean flags and the
// strategy name -- confirm against the macro definition in gemm_common.h.
MEGDNN_REG_GEMM_STRATEGY
(
dt_int8
,
dt_int16
,
dt_int16
,
8
,
8
,
8
,
false
,
true
,
gemm_s4x4x16_s4_8x8x8
);
}
// namespace matmul
}
// namespace aarch64
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
86cf7490
...
...
@@ -50,6 +50,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#else
AlgoQuint8K8x8x8
quint8_k8x8x8
;
#endif
AlgoInt4x4x16K8x8x8
int4x4x16_k8x8x8
;
SmallVector
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
m_all_algos
;
fallback
::
MatrixMulImpl
::
AlgoBase
::
Mapper
m_all_algos_map
;
...
...
@@ -87,6 +88,7 @@ public:
#else
m_all_algos
.
emplace_back
(
&
quint8_k8x8x8
);
#endif
m_all_algos
.
emplace_back
(
&
int4x4x16_k8x8x8
);
for
(
auto
&&
algo
:
m_all_algos
)
{
m_all_algos_map
.
emplace
(
algo
->
info
().
desc
,
algo
);
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
86cf7490
...
...
@@ -66,8 +66,8 @@ private:
#else
class
AlgoQuint8K8x8x8
;
// Aarch64 Quint8 Kernel 8x8x8
#endif
class
AlgoInt8x8x16MK4_K8x8x8
;
// Aarch64 Int
4x4
x16 Kernel 4x4x16
class
AlgoInt8x8x16MK4_K8x8x8
;
// Aarch64 Int
8x8
x16 Kernel 4x4x16
class
AlgoInt4x4x16K8x8x8
;
// Aarch64 Int4x4x16 Kernel 8x8x8
class
AlgoPack
;
public:
static
const
AlgoPack
&
algo_pack
();
...
...
dnn/src/common/matrix_mul.cpp
浏览文件 @
86cf7490
...
...
@@ -33,6 +33,8 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
C_candi
=
dtype
::
QuantizedS32
(
mul_scale
(
A
,
B
));
}
else
if
(
A
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
C_candi
=
dtype
::
QuantizedS32
(
mul_scale
(
A
,
B
));
}
else
if
(
A
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
C_candi
=
dtype
::
QuantizedS16
(
mul_scale
(
A
,
B
));
}
if
(
!
C
.
valid
())
{
C
=
C_candi
;
...
...
@@ -169,6 +171,8 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B,
A
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
A
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
megdnn_assert
(
C
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS32
);
}
else
if
(
A
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
){
megdnn_assert
(
C
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS16
);
}
megdnn_assert
(
param
().
compute_mode
!=
Param
::
ComputeMode
::
FLOAT32
MEGDNN_INC_FLOAT16
(
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
86cf7490
...
...
@@ -154,6 +154,7 @@ public:
AARCH64_QUINT8_K8X8X4_DOTPROD
,
AARCH64_QUINT8_GEMV_DOTPROD
,
AARCH64_QUINT8_K8X8X8
,
AARCH64_INT4X4X16_K8X8X8
,
#else
ARMV7_F32
=
1
<<
16
,
ARMV7_F32_MK4_PACK_4X12
,
...
...
dnn/src/naive/matrix_mul/matrix_mul_helper.h
浏览文件 @
86cf7490
...
...
@@ -179,6 +179,42 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A,
C
.
compatible_ptr
<
dt_int32
>
(),
M
,
N
,
K
,
LDA
,
LDB
,
LDC
,
nA
.
layout
.
dtype
,
nB
.
layout
.
dtype
);
}
/*!
 * \brief Naive QuantizedS4 x QuantizedS4 -> int16 matrix multiplication.
 *
 * Both inputs are first expanded from packed 4-bit values (two per byte,
 * low nibble first) into one sign-extended int8 per element inside
 * \p workspace, then multiplied with run_matrix_mul_tpl.
 *
 * \tparam transA    forwarded to run_matrix_mul_tpl
 * \tparam transB    forwarded to run_matrix_mul_tpl
 * \param A          left operand, QuantizedS4
 * \param B          right operand, QuantizedS4
 * \param C          output; accessed as dt_int16
 * \param workspace  scratch holding the expanded A followed by the expanded B
 * \param param      only param.transposeA is read here, to pick which
 *                   dimension of A is K
 */
template <bool transA, bool transB>
void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B,
                                       _megdnn_tensor_out C,
                                       _megdnn_workspace workspace,
                                       const param::MatrixMul& param) {
    // Same shape and strides as the 4-bit layout, but with a QuantizedS8
    // dtype carrying the original scale.
    auto convert_layout = [](const TensorLayout& layout) {
        auto ret = layout;
        auto qparam = layout.dtype.param<dtype::QuantizedS4>();
        ret.dtype = dtype::QuantizedS8(qparam.scale);
        return ret;
    };
    TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)};
    TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(),
                   convert_layout(B.layout)};

    // Unpack two signed nibbles per input byte into two int8 outputs.
    auto convert_4to8 = [](const TensorND& in, const TensorND& out) {
        auto ptr = static_cast<int8_t*>(in.raw_ptr) +
                   in.layout.span().low_byte;
        auto out_ptr =
                out.compatible_ptr<int8_t>() + out.layout.span().low_byte;
        const size_t nr_elems = in.layout.span().dist_elem();
        for (size_t i = 0; i < nr_elems; i += 2) {
            int8_t cur = ptr[i / 2];
            // Low nibble: move it to the top of the byte via an unsigned
            // shift (left-shifting a negative value is undefined before
            // C++20), then shift back arithmetically to sign-extend it.
            const int8_t low_in_high =
                    static_cast<int8_t>(static_cast<uint8_t>(cur) << 4);
            out_ptr[i] = low_in_high >> 4;
            // High nibble: skipped for the padding nibble of an odd element
            // count, which would otherwise be written one byte past the
            // region reserved for this tensor.
            if (i + 1 < nr_elems) {
                out_ptr[i + 1] = cur >> 4;
            }
        }
    };
    convert_4to8(A, nA);
    convert_4to8(B, nB);

    auto M = C.layout.shape[0], N = C.layout.shape[1];
    auto K = A.layout.shape[param.transposeA ? 0 : 1];
    auto LDA = A.layout.stride[0], LDB = B.layout.stride[0],
         LDC = C.layout.stride[0];
    run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>(
            nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(),
            C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC,
            nA.layout.dtype, nB.layout.dtype);
}
}
// namespace naive
}
// namespace megdnn
...
...
dnn/src/naive/matrix_mul/opr_impl.cpp
浏览文件 @
86cf7490
...
...
@@ -26,7 +26,8 @@ size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
MIDOUT_BEGIN
(
megdnn_naive_matmul
,
midout_iv
(
"MatrixMulForwardImpl::get_workspace_in_bytes"
_hash
))
{
if
(
A
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
if
(
A
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
||
A
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
return
(
A
.
span
().
dist_elem
()
+
B
.
span
().
dist_elem
())
*
sizeof
(
uint8_t
);
}
...
...
@@ -104,6 +105,11 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B,
param
.
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
)
{
exec_matrix_mul_quint4x4x32_helper
<
TA
,
TB
>
(
A
,
B
,
C
,
workspace
,
param
);
return
;
}
else
if
(
A
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
&&
C
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS16
&&
param
.
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
)
{
exec_matrix_mul_qint4x4x16_helper
<
TA
,
TB
>
(
A
,
B
,
C
,
workspace
,
param
);
return
;
}
#undef cb
megdnn_throw
(
ssprintf
(
...
...
dnn/test/aarch64/matrix_mul.cpp
浏览文件 @
86cf7490
...
...
@@ -164,6 +164,55 @@ TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_K4x4x16) {
handle
(),
"AARCH64_INT8X8X16_K4X4X16"
);
}
// Correctness test for the AARCH64_INT4X4X16_K8X8X8 matmul algorithm
// (QuantizedS4 x QuantizedS4 -> QuantizedS16) over a range of M/N/K sizes,
// followed by a few runs with transposeA and then transposeA + transposeB.
TEST_F(AARCH64, MATRIX_MUL_INT4x4x16_K8x8x8_QUANTIZEDS4) {
    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;
    Checker<MatrixMul> checker(handle());
    checker.set_dtype(0, dtype::QuantizedS4{0.6})
            .set_dtype(1, dtype::QuantizedS4{0.5})
            .set_dtype(2, dtype::QuantizedS16{0.6 * 0.5})
            .set_param(param);
    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT4X4X16_K8X8X8"));

    auto run = [&](size_t M, size_t N, size_t K) {
        printf("M N K %zu %zu %zu\n", M, N, K);
        TensorShape A, B;
        if (param.transposeA) {
            A = TensorShape{K, M};
        } else {
            A = TensorShape{M, K};
        }
        if (param.transposeB) {
            B = TensorShape{N, K};
        } else {
            B = TensorShape{K, N};
        }
        // Re-apply the param before every run: the transpose flags are
        // flipped further down, after the initial set_param, and the shapes
        // built above only describe the intended problem if the checker runs
        // with the same flags. (Assumes set_param stores a copy -- TODO
        // confirm; if it keeps a reference this call is a harmless no-op.)
        checker.set_param(param);
        checker.exec({A, B, {}});
    };

    for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 20})
        for (size_t n : {2, 4, 6, 8, 10, 12, 14, 16, 24})
            for (size_t k : {2, 4, 6, 8, 10, 12, 14, 16, 32})
                run(m, n, k);

    for (size_t k = 4; k <= 256; k *= 8) {
        for (size_t m = 4; m <= 256; m *= 4) {
            for (size_t n = 4; n <= 256; n *= 4) {
                run(m, n, k);
            }
        }
    }

    // A transposed, B not transposed.
    param.transposeA = true;
    run(8, 8, 8);
    run(16, 8, 16);

    // Both operands transposed.
    param.transposeB = true;
    run(8, 8, 8);
    run(16, 16, 16);
}
TEST_F
(
AARCH64
,
MATRIX_MUL_INT16x16x32_K12X8X1
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Int16
{},
dtype
::
Int16
{},
dtype
::
Int32
{},
handle
(),
"AARCH64_INT16X16X32_K12X8X1"
);
...
...
@@ -410,6 +459,63 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
run
(
384
,
384
,
384
);
}
// Benchmark: QuantizedS4 x QuantizedS4 -> QuantizedS16 matmul against the
// Int8 x Int8 -> Int16 AARCH64_INT8X8X16_K4X4X16 algorithm, printing time,
// Gflops and speedup for every shape.
TEST_F(AARCH64, BENCHMARK_4x4x16_vs_8x8x16) {
    constexpr size_t RUNS = 50;

    param::MatrixMul param;
    param.transposeA = false;
    param.transposeB = false;

    Benchmarker<MatrixMul> benchmarker(handle());
    Benchmarker<MatrixMul> benchmarker_int4_4x4x16(handle());

    // int4 candidate: no algorithm is pinned, so it uses whatever the
    // handle selects for these dtypes.
    benchmarker_int4_4x4x16.set_times(RUNS)
            .set_dtype(0, dtype::QuantizedS4{0.3})
            .set_dtype(1, dtype::QuantizedS4{0.3})
            .set_dtype(2, dtype::QuantizedS16{0.09})
            .set_param(param)
            .set_display(false);

    // int8 baseline, pinned to the K4x4x16 algorithm.
    benchmarker.set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_param(param)
            .set_display(false);
    benchmarker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));

    auto run = [&](size_t M, size_t N, size_t K) {
        auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
        auto int4416_used =
                benchmarker_int4_4x4x16.exec({{M, K}, {K, N}, {}}) / RUNS;
        // 2*M*K*N operations scaled by 1e-6; divided by the per-run time
        // this is what the output labels as Gflops.
        float computations = 2.f * M * K * N * 1e-6;
        printf("run: {%zu{M} %zu{K} %zu{N}} normal 8x8x16 used: %f ms %f "
               "Gflops int4416 used %f int4416_gflops %f speedup %f\n",
               M, K, N, default_used, computations / default_used,
               int4416_used, computations / int4416_used,
               default_used / int4416_used);
    };

    // Dense sweep in steps of 32.
    for (int m = 32; m <= 1024; m += 32) {
        for (int n = 32; n <= 1024; n += 32) {
            for (int k = 32; k <= 512; k += 32) {
                run(m, n, k);
            }
        }
    }

    // Hand-picked shapes, run in the listed order.
    const size_t extra_shapes[][3] = {
            {32, 32, 32},
            {32, 32, 8},
            {32, 32, 16},
            {32, 32, 24},
            {32 * 2, 32 * 2, 32},
            {32 * 4, 32 * 4, 32},
            {32 * 6, 32 * 6, 32},
            {32 * 8, 32 * 8, 32},
            {32 * 2, 32 * 2, 32 * 2},
            {32 * 4, 32 * 4, 32 * 3},
            {32 * 6, 32 * 6, 32 * 4},
            {32 * 8, 32 * 8, 32 * 5},
            {32 * 10, 32 * 10, 32 * 10},
            {384, 384, 384},
            {256, 256, 384},
            {512, 512, 384},
            {1024, 1024, 384},
    };
    for (const auto& shape : extra_shapes) {
        run(shape[0], shape[1], shape[2]);
    }
}
TEST_F
(
AARCH64
,
BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16
)
{
constexpr
size_t
RUNS
=
50
;
param
::
MatrixMul
param
;
...
...
dnn/test/common/rng.cpp
浏览文件 @
86cf7490
...
...
@@ -183,6 +183,34 @@ void IIDRNG::gen(const TensorND& tensor) {
}
return
;
}
if
(
tensor
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
{
auto
ptr
=
static_cast
<
int8_t
*>
(
tensor
.
raw_ptr
);
if
(
output_is_float
())
{
for
(
size_t
i
=
0
;
i
<
nr_elems
;
i
+=
2
)
{
int8_t
val0
=
tensor
.
layout
.
dtype
.
param
<
dt_qint4
>
()
.
quantize
(
static_cast
<
float
>
(
gen_single_val
()))
.
as_int8
();
int8_t
val1
=
tensor
.
layout
.
dtype
.
param
<
dt_qint4
>
()
.
quantize
(
static_cast
<
float
>
(
gen_single_val
()))
.
as_int8
();
ptr
[(
offset
+
i
)
/
2
]
=
(
val0
&
0xF
)
|
(
val1
<<
4
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
nr_elems
;
i
+=
2
)
{
int8_t
val0
=
static_cast
<
int8_t
>
(
gen_single_val
());
int8_t
val1
=
static_cast
<
int8_t
>
(
gen_single_val
());
val0
=
std
::
min
(
val0
,
DTypeTrait
<
dtype
::
QuantizedS4
>::
max
());
val0
=
std
::
max
(
val0
,
DTypeTrait
<
dtype
::
QuantizedS4
>::
min
());
val1
=
std
::
min
(
val1
,
DTypeTrait
<
dtype
::
QuantizedS4
>::
max
());
val1
=
std
::
max
(
val1
,
DTypeTrait
<
dtype
::
QuantizedS4
>::
min
());
ptr
[(
offset
+
i
)
/
2
]
=
(
val0
&
0xF
)
|
(
val1
<<
4
);
}
}
return
;
}
megdnn_assert
(
0
,
"IIDRNG does not know how to generate value for DType %s"
,
tensor
.
layout
.
dtype
.
name
());
}
...
...
dnn/test/naive/matrix_mul.cpp
浏览文件 @
86cf7490
...
...
@@ -203,6 +203,67 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZED4x4x32) {
});
}
// Fixed-value test of the naive QuantizedS4 x QuantizedS4 -> QuantizedS16
// matmul: two hand-written 8x8 inputs and the expected 8x8 product.
TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) {
    Checker<MatrixMul> checker(handle(), /* check_dispatch */ false);

    // Build a QuantizedS4 tensor from plain ints, packing two values per
    // byte: even index in the low nibble, odd index in the high nibble.
    auto make_packed_s4 = [](const TensorShape& shape,
                             dtype::QuantizedS4 dtype,
                             const std::vector<int>& values) {
        TensorND tensor;
        tensor.layout = {shape, dtype};
        tensor.raw_ptr = static_cast<dt_byte*>(
                malloc(tensor.layout.span().dist_byte()));
        uint8_t* packed = static_cast<uint8_t*>(tensor.raw_ptr);
        megdnn_assert(values.size() == tensor.layout.span().dist_elem());
        const size_t nr_elems = tensor.layout.span().dist_elem();
        for (size_t byte_idx = 0; byte_idx * 2 < nr_elems; ++byte_idx) {
            const int low = values[byte_idx * 2];
            const int high = values[byte_idx * 2 + 1];
            packed[byte_idx] = (low & 0xF) | (high << 4);
        }
        return tensor;
    };

    // Row-major 8x8 left operand.
    const std::vector<int> lhs_values = {
            -8, 7,  2,  1,  2,  3, 2,  7,  //
            2,  5,  3,  3,  7,  4, -7, 1,  //
            -5, 7,  -4, -1, -1, 2, 4,  1,  //
            7,  2,  -6, -2, -6, 3, 4,  4,  //
            -2, 2,  3,  0,  6,  5, 3,  4,  //
            -1, -1, -5, 5,  2,  5, 1,  4,  //
            6,  2,  0,  0,  3,  2, 2,  1,  //
            -4, -3, 7,  5,  0,  3, 2,  3};
    // Row-major 8x8 right operand.
    const std::vector<int> rhs_values = {
            5,  -8, -7, -6, 4,  7,  -5, -5,  //
            -4, 7,  -3, -2, 5,  6,  4,  2,   //
            3,  -1, 2,  2,  7,  3,  6,  0,   //
            5,  4,  0,  2,  2,  3,  3,  2,   //
            1,  -8, -7, -6, 0,  -5, -4, 4,   //
            -3, 7,  1,  6,  -2, 2,  -1, 5,   //
            2,  0,  7,  6,  5,  4,  3,  2,   //
            0,  0,  1,  0,  5,  2,  2,  6};

    using Param = MatrixMul::Param;
    Param param;
    checker.set_param(param);
    checker.set_dtype(2, dtype::QuantizedS16(0.3f * 0.3f));
    checker.exect(
            Testcase{make_packed_s4({8, 8}, dtype::QuantizedS4(0.3f),
                                    lhs_values),
                     make_packed_s4({8, 8}, dtype::QuantizedS4(0.3f),
                                    rhs_values),
                     {}},
            Testcase{{},
                     {},
                     TensorValue({8, 8}, dtype::QuantizedS16(0.3f * 0.3f),
                                 {-60, 120, 49,   58,  58,  13,  92,  125,
                                  -5,  0,   -116, -70, 22,  9,   -14, 46,
                                  -69, 111, 44,   48,  6,   19,  42,  57,
                                  -8,  25,  10,   16,  26,  97,  -28, -12,
                                  -12, 14,  2,    26,  48,  7,   24,  93,
                                  -2,  45,  2,    32,  -19, -1,  -16, 72,
                                  23,  -44, -52,  -34, 45,  53,  -28, 6,
                                  33,  45,  71,   84,  47,  10,  74,  61})});
}
TEST_F
(
NAIVE
,
MATRIX_MUL_QUANTIZED8x8x32
)
{
Checker
<
MatrixMul
>
checker
(
handle
(),
/* check_dispatch */
false
);
MatrixMul
::
Param
param
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录