Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a450d0f5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a450d0f5
编写于
1月 31, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(fallback): add fp16 mk8 8x8 matmul
GitOrigin-RevId: 1a50a8a7be433f3ba688eeaa1af460458e5f0b53
上级
b85792ac
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
592 addition
and
0 deletion
+592
-0
dnn/src/fallback/matrix_mul/algos.cpp
dnn/src/fallback/matrix_mul/algos.cpp
+59
-0
dnn/src/fallback/matrix_mul/algos.h
dnn/src/fallback/matrix_mul/algos.h
+15
-0
dnn/src/fallback/matrix_mul/generic_strategy.h
dnn/src/fallback/matrix_mul/generic_strategy.h
+6
-0
dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp
dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp
+485
-0
dnn/src/fallback/matrix_mul/opr_impl.cpp
dnn/src/fallback/matrix_mul/opr_impl.cpp
+7
-0
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+2
-0
dnn/test/fallback/matrix_mul.cpp
dnn/test/fallback/matrix_mul.cpp
+18
-0
未找到文件。
dnn/src/fallback/matrix_mul/algos.cpp
浏览文件 @
a450d0f5
...
@@ -297,7 +297,66 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
...
@@ -297,7 +297,66 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const
KernSizeParam
&
)
const
{
const
KernSizeParam
&
)
const
{
return
gi_f32_mk4_4x8_kern
;
return
gi_f32_mk4_4x8_kern
;
}
}
#if defined(GI_SUPPORT_F16)
/* ================== F16 Gemm MK8 gi algo ================== */
namespace
{
void
gi_f16_mk8_8x8_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
MIDOUT_BEGIN
(
megdnn_fb_gi_matmul_kern
,
midout_iv
(
"gi_f16_mk8_8x8_kern"
_hash
))
{
auto
M
=
kern_param
.
M
,
N
=
kern_param
.
N
,
K
=
kern_param
.
K
;
auto
trA
=
kern_param
.
trA
,
trB
=
kern_param
.
trB
;
auto
LDA
=
kern_param
.
LDA
,
LDB
=
kern_param
.
LDB
,
LDC
=
kern_param
.
LDC
;
auto
A_type
=
kern_param
.
A_type
,
B_type
=
kern_param
.
B_type
,
C_type
=
kern_param
.
C_type
;
const
auto
Aptr
=
kern_param
.
A
<
dt_float16
>
(),
Bptr
=
kern_param
.
B
<
dt_float16
>
();
auto
Cptr
=
kern_param
.
C
<
dt_float16
>
();
matmul
::
fallback
::
gi_sgemm_nopack_mk8_8x8_fp16
strategy
(
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
fallback
::
gi_sgemm_nopack_mk8_8x8_fp16
,
false
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
}
MIDOUT_END
();
}
}
// anonymous namespace
bool
MatrixMulImpl
::
AlgoF16GiMK8_8x8
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
constexpr
size_t
MB
=
8
;
constexpr
size_t
KB
=
8
;
return
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
MK8
&&
kern_size_param
.
B_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
C_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
A_type
==
dtype
::
Float16
()
&&
!
kern_size_param
.
trA
&&
!
kern_size_param
.
trB
&&
kern_size_param
.
M
%
MB
==
0
&&
kern_size_param
.
K
%
KB
==
0
;
}
size_t
MatrixMulImpl
::
AlgoF16GiMK8_8x8
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_fb_gi_matmul_kern
,
midout_iv
(
"AlgoF16GiMK8_8x8::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
matmul
::
fallback
::
gi_sgemm_nopack_mk8_8x8_fp16
strategy
(
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
fallback
::
gi_sgemm_nopack_mk8_8x8_fp16
,
false
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
return
0
;
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoF16GiMK8_8x8
::
get_kern
(
const
KernSizeParam
&
)
const
{
return
gi_f16_mk8_8x8_kern
;
}
#endif
/* ===================== F32 algo gi mk4 pack K4x12 ===================== */
/* ===================== F32 algo gi mk4 pack K4x12 ===================== */
namespace
{
namespace
{
void
f32_gi_mk4_pack_4x12_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
void
f32_gi_mk4_pack_4x12_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
...
dnn/src/fallback/matrix_mul/algos.h
浏览文件 @
a450d0f5
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <type_traits>
#include <type_traits>
#include "src/common/algo_base.h"
#include "src/common/algo_base.h"
#include "src/fallback/general_intrinsic/gi_common.h"
#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/fallback/matrix_mul/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"
...
@@ -97,6 +98,20 @@ public:
...
@@ -97,6 +98,20 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
FB_GI_F32_MK4_4x8
)
MEGDNN_DECL_ALGO_TYPE
(
FB_GI_F32_MK4_4x8
)
};
};
#if defined(GI_SUPPORT_F16)
class
MatrixMulImpl
::
AlgoF16GiMK8_8x8
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
;
}
const
char
*
name
()
const
override
{
return
"FB_GI_F16_MK8_8x8"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
8
,
AlgoDataType
::
FLOAT16
,
MK8
)
MEGDNN_DECL_ALGO_TYPE
(
FB_GI_F16_MK8_8x8
)
};
#endif
class
MatrixMulImpl
::
AlgoF32GiMK4Pack4x12
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoF32GiMK4Pack4x12
final
:
public
AlgoBase
{
public:
public:
AlgoAttribute
attribute
()
const
override
{
AlgoAttribute
attribute
()
const
override
{
...
...
dnn/src/fallback/matrix_mul/generic_strategy.h
浏览文件 @
a450d0f5
#pragma once
#pragma once
#include "src/fallback/general_intrinsic/gi_common.h"
#include "src/fallback/matrix_mul/gemm_common.h"
#include "src/fallback/matrix_mul/gemm_common.h"
namespace
megdnn
{
namespace
megdnn
{
...
@@ -8,6 +9,11 @@ namespace fallback {
...
@@ -8,6 +9,11 @@ namespace fallback {
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
8
,
12
,
1
,
false
,
true
,
sgemm_8x12
);
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
8
,
12
,
1
,
false
,
true
,
sgemm_8x12
);
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
float
,
float
,
float
,
4
,
8
,
1
,
false
,
true
,
gi_sgemm_nopack_4x8
);
float
,
float
,
float
,
4
,
8
,
1
,
false
,
true
,
gi_sgemm_nopack_4x8
);
#if defined(GI_SUPPORT_F16)
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
dt_float16
,
dt_float16
,
dt_float16
,
8
,
8
,
1
,
false
,
true
,
gi_sgemm_nopack_mk8_8x8_fp16
);
#endif
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
12
,
1
,
false
,
true
,
gi_sgemm_4x12
);
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
12
,
1
,
false
,
true
,
gi_sgemm_4x12
);
MEGDNN_REG_GEMM_STRATEGY
(
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
12
,
1
,
false
,
false
,
gi_sgemm_mk4_pack_4x12
);
float
,
float
,
float
,
4
,
12
,
1
,
false
,
false
,
gi_sgemm_mk4_pack_4x12
);
...
...
dnn/src/fallback/matrix_mul/gi/fp16/strategy_mk8_8x8.cpp
0 → 100644
浏览文件 @
a450d0f5
#include "src/fallback/general_intrinsic/gi_float16.h"
#if defined(GI_SUPPORT_F16)
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
using
namespace
megdnn
;
using
namespace
matmul
::
fallback
;
namespace
{
#define MLA GiMultiplyAddScalarFloat16
void
kern_8x1
(
const
gi_float16_t
*
A
,
const
gi_float16_t
*
B
,
size_t
LDB
,
size_t
K
,
gi_float16_t
*
C
)
{
LDB
=
LDB
-
8
;
K
=
K
-
8
;
GI_FLOAT16_t
d0
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d1
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d2
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d3
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d4
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d5
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d6
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d7
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
vfzero
=
GiBroadcastFloat16
(
0.0
);
GI_FLOAT16_t
d8
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d8
=
MLA
(
d8
,
d1
,
*
(
B
+
1
));
d8
=
MLA
(
d8
,
d2
,
*
(
B
+
2
));
d8
=
MLA
(
d8
,
d3
,
*
(
B
+
3
));
d8
=
MLA
(
d8
,
d4
,
*
(
B
+
4
));
d8
=
MLA
(
d8
,
d5
,
*
(
B
+
5
));
d8
=
MLA
(
d8
,
d6
,
*
(
B
+
6
));
d8
=
MLA
(
d8
,
d7
,
*
(
B
+
7
));
B
+=
8
;
B
+=
LDB
;
for
(;
K
>
0
;
K
-=
8
)
{
d0
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d1
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d2
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d3
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d4
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d5
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d6
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d7
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d8
=
MLA
(
d8
,
d0
,
*
(
B
));
d8
=
MLA
(
d8
,
d1
,
*
(
B
+
1
));
d8
=
MLA
(
d8
,
d2
,
*
(
B
+
2
));
d8
=
MLA
(
d8
,
d3
,
*
(
B
+
3
));
d8
=
MLA
(
d8
,
d4
,
*
(
B
+
4
));
d8
=
MLA
(
d8
,
d5
,
*
(
B
+
5
));
d8
=
MLA
(
d8
,
d6
,
*
(
B
+
6
));
d8
=
MLA
(
d8
,
d7
,
*
(
B
+
7
));
B
+=
8
;
B
+=
LDB
;
}
GiStoreFloat16
(
C
,
d8
);
}
void
kern_8x4
(
const
gi_float16_t
*
A
,
const
gi_float16_t
*
B
,
size_t
LDB
,
size_t
K
,
gi_float16_t
*
C
)
{
LDB
=
LDB
-
32
;
K
=
K
-
8
;
GI_FLOAT16_t
d0
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d1
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d2
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d3
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d4
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d5
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d6
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d7
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
vfzero
=
GiBroadcastFloat16
(
0.0
);
GI_FLOAT16_t
d8
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d8
=
MLA
(
d8
,
d1
,
*
(
B
+
1
));
d8
=
MLA
(
d8
,
d2
,
*
(
B
+
2
));
d8
=
MLA
(
d8
,
d3
,
*
(
B
+
3
));
d8
=
MLA
(
d8
,
d4
,
*
(
B
+
4
));
d8
=
MLA
(
d8
,
d5
,
*
(
B
+
5
));
d8
=
MLA
(
d8
,
d6
,
*
(
B
+
6
));
d8
=
MLA
(
d8
,
d7
,
*
(
B
+
7
));
B
+=
8
;
GI_FLOAT16_t
d9
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d9
=
MLA
(
d9
,
d1
,
*
(
B
+
1
));
d9
=
MLA
(
d9
,
d2
,
*
(
B
+
2
));
d9
=
MLA
(
d9
,
d3
,
*
(
B
+
3
));
d9
=
MLA
(
d9
,
d4
,
*
(
B
+
4
));
d9
=
MLA
(
d9
,
d5
,
*
(
B
+
5
));
d9
=
MLA
(
d9
,
d6
,
*
(
B
+
6
));
d9
=
MLA
(
d9
,
d7
,
*
(
B
+
7
));
B
+=
8
;
GI_FLOAT16_t
d10
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d10
=
MLA
(
d10
,
d1
,
*
(
B
+
1
));
d10
=
MLA
(
d10
,
d2
,
*
(
B
+
2
));
d10
=
MLA
(
d10
,
d3
,
*
(
B
+
3
));
d10
=
MLA
(
d10
,
d4
,
*
(
B
+
4
));
d10
=
MLA
(
d10
,
d5
,
*
(
B
+
5
));
d10
=
MLA
(
d10
,
d6
,
*
(
B
+
6
));
d10
=
MLA
(
d10
,
d7
,
*
(
B
+
7
));
B
+=
8
;
GI_FLOAT16_t
d11
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d11
=
MLA
(
d11
,
d1
,
*
(
B
+
1
));
d11
=
MLA
(
d11
,
d2
,
*
(
B
+
2
));
d11
=
MLA
(
d11
,
d3
,
*
(
B
+
3
));
d11
=
MLA
(
d11
,
d4
,
*
(
B
+
4
));
d11
=
MLA
(
d11
,
d5
,
*
(
B
+
5
));
d11
=
MLA
(
d11
,
d6
,
*
(
B
+
6
));
d11
=
MLA
(
d11
,
d7
,
*
(
B
+
7
));
B
+=
8
;
B
+=
LDB
;
for
(;
K
>
0
;
K
-=
8
)
{
d0
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d1
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d2
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d3
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d4
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d5
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d6
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d7
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d8
=
MLA
(
d8
,
d0
,
*
(
B
));
d8
=
MLA
(
d8
,
d1
,
*
(
B
+
1
));
d8
=
MLA
(
d8
,
d2
,
*
(
B
+
2
));
d8
=
MLA
(
d8
,
d3
,
*
(
B
+
3
));
d8
=
MLA
(
d8
,
d4
,
*
(
B
+
4
));
d8
=
MLA
(
d8
,
d5
,
*
(
B
+
5
));
d8
=
MLA
(
d8
,
d6
,
*
(
B
+
6
));
d8
=
MLA
(
d8
,
d7
,
*
(
B
+
7
));
B
+=
8
;
d9
=
MLA
(
d9
,
d0
,
*
(
B
));
d9
=
MLA
(
d9
,
d1
,
*
(
B
+
1
));
d9
=
MLA
(
d9
,
d2
,
*
(
B
+
2
));
d9
=
MLA
(
d9
,
d3
,
*
(
B
+
3
));
d9
=
MLA
(
d9
,
d4
,
*
(
B
+
4
));
d9
=
MLA
(
d9
,
d5
,
*
(
B
+
5
));
d9
=
MLA
(
d9
,
d6
,
*
(
B
+
6
));
d9
=
MLA
(
d9
,
d7
,
*
(
B
+
7
));
B
+=
8
;
d10
=
MLA
(
d10
,
d0
,
*
(
B
));
d10
=
MLA
(
d10
,
d1
,
*
(
B
+
1
));
d10
=
MLA
(
d10
,
d2
,
*
(
B
+
2
));
d10
=
MLA
(
d10
,
d3
,
*
(
B
+
3
));
d10
=
MLA
(
d10
,
d4
,
*
(
B
+
4
));
d10
=
MLA
(
d10
,
d5
,
*
(
B
+
5
));
d10
=
MLA
(
d10
,
d6
,
*
(
B
+
6
));
d10
=
MLA
(
d10
,
d7
,
*
(
B
+
7
));
B
+=
8
;
d11
=
MLA
(
d11
,
d0
,
*
(
B
));
d11
=
MLA
(
d11
,
d1
,
*
(
B
+
1
));
d11
=
MLA
(
d11
,
d2
,
*
(
B
+
2
));
d11
=
MLA
(
d11
,
d3
,
*
(
B
+
3
));
d11
=
MLA
(
d11
,
d4
,
*
(
B
+
4
));
d11
=
MLA
(
d11
,
d5
,
*
(
B
+
5
));
d11
=
MLA
(
d11
,
d6
,
*
(
B
+
6
));
d11
=
MLA
(
d11
,
d7
,
*
(
B
+
7
));
B
+=
8
;
B
+=
LDB
;
}
GiStoreFloat16
(
C
,
d8
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d9
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d10
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d11
);
C
=
C
+
8
;
}
void
kern_8x8
(
const
gi_float16_t
*
A
,
const
gi_float16_t
*
B
,
size_t
LDB
,
size_t
K
,
gi_float16_t
*
C
)
{
LDB
-=
64
;
K
=
K
-
8
;
GI_FLOAT16_t
d0
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d1
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d2
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d3
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d4
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d5
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d6
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
d7
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
GI_FLOAT16_t
vfzero
=
GiZeroFloat16
();
GI_FLOAT16_t
d8
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d8
=
MLA
(
d8
,
d1
,
*
(
B
+
1
));
d8
=
MLA
(
d8
,
d2
,
*
(
B
+
2
));
d8
=
MLA
(
d8
,
d3
,
*
(
B
+
3
));
d8
=
MLA
(
d8
,
d4
,
*
(
B
+
4
));
d8
=
MLA
(
d8
,
d5
,
*
(
B
+
5
));
d8
=
MLA
(
d8
,
d6
,
*
(
B
+
6
));
d8
=
MLA
(
d8
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d9
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d9
=
MLA
(
d9
,
d1
,
*
(
B
+
1
));
d9
=
MLA
(
d9
,
d2
,
*
(
B
+
2
));
d9
=
MLA
(
d9
,
d3
,
*
(
B
+
3
));
d9
=
MLA
(
d9
,
d4
,
*
(
B
+
4
));
d9
=
MLA
(
d9
,
d5
,
*
(
B
+
5
));
d9
=
MLA
(
d9
,
d6
,
*
(
B
+
6
));
d9
=
MLA
(
d9
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d10
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d10
=
MLA
(
d10
,
d1
,
*
(
B
+
1
));
d10
=
MLA
(
d10
,
d2
,
*
(
B
+
2
));
d10
=
MLA
(
d10
,
d3
,
*
(
B
+
3
));
d10
=
MLA
(
d10
,
d4
,
*
(
B
+
4
));
d10
=
MLA
(
d10
,
d5
,
*
(
B
+
5
));
d10
=
MLA
(
d10
,
d6
,
*
(
B
+
6
));
d10
=
MLA
(
d10
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d11
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d11
=
MLA
(
d11
,
d1
,
*
(
B
+
1
));
d11
=
MLA
(
d11
,
d2
,
*
(
B
+
2
));
d11
=
MLA
(
d11
,
d3
,
*
(
B
+
3
));
d11
=
MLA
(
d11
,
d4
,
*
(
B
+
4
));
d11
=
MLA
(
d11
,
d5
,
*
(
B
+
5
));
d11
=
MLA
(
d11
,
d6
,
*
(
B
+
6
));
d11
=
MLA
(
d11
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d12
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d12
=
MLA
(
d12
,
d1
,
*
(
B
+
1
));
d12
=
MLA
(
d12
,
d2
,
*
(
B
+
2
));
d12
=
MLA
(
d12
,
d3
,
*
(
B
+
3
));
d12
=
MLA
(
d12
,
d4
,
*
(
B
+
4
));
d12
=
MLA
(
d12
,
d5
,
*
(
B
+
5
));
d12
=
MLA
(
d12
,
d6
,
*
(
B
+
6
));
d12
=
MLA
(
d12
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d13
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d13
=
MLA
(
d13
,
d1
,
*
(
B
+
1
));
d13
=
MLA
(
d13
,
d2
,
*
(
B
+
2
));
d13
=
MLA
(
d13
,
d3
,
*
(
B
+
3
));
d13
=
MLA
(
d13
,
d4
,
*
(
B
+
4
));
d13
=
MLA
(
d13
,
d5
,
*
(
B
+
5
));
d13
=
MLA
(
d13
,
d6
,
*
(
B
+
6
));
d13
=
MLA
(
d13
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d14
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d14
=
MLA
(
d14
,
d1
,
*
(
B
+
1
));
d14
=
MLA
(
d14
,
d2
,
*
(
B
+
2
));
d14
=
MLA
(
d14
,
d3
,
*
(
B
+
3
));
d14
=
MLA
(
d14
,
d4
,
*
(
B
+
4
));
d14
=
MLA
(
d14
,
d5
,
*
(
B
+
5
));
d14
=
MLA
(
d14
,
d6
,
*
(
B
+
6
));
d14
=
MLA
(
d14
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
GI_FLOAT16_t
d15
=
MLA
(
vfzero
,
d0
,
*
(
B
));
d15
=
MLA
(
d15
,
d1
,
*
(
B
+
1
));
d15
=
MLA
(
d15
,
d2
,
*
(
B
+
2
));
d15
=
MLA
(
d15
,
d3
,
*
(
B
+
3
));
d15
=
MLA
(
d15
,
d4
,
*
(
B
+
4
));
d15
=
MLA
(
d15
,
d5
,
*
(
B
+
5
));
d15
=
MLA
(
d15
,
d6
,
*
(
B
+
6
));
d15
=
MLA
(
d15
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
B
=
B
+
LDB
;
for
(;
K
>
0
;
K
-=
8
)
{
d0
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d1
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d2
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d3
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d4
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d5
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d6
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d7
=
GiLoadFloat16
(
A
);
A
=
A
+
8
;
d8
=
MLA
(
d8
,
d0
,
*
(
B
));
d8
=
MLA
(
d8
,
d1
,
*
(
B
+
1
));
d8
=
MLA
(
d8
,
d2
,
*
(
B
+
2
));
d8
=
MLA
(
d8
,
d3
,
*
(
B
+
3
));
d8
=
MLA
(
d8
,
d4
,
*
(
B
+
4
));
d8
=
MLA
(
d8
,
d5
,
*
(
B
+
5
));
d8
=
MLA
(
d8
,
d6
,
*
(
B
+
6
));
d8
=
MLA
(
d8
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d9
=
MLA
(
d9
,
d0
,
*
(
B
));
d9
=
MLA
(
d9
,
d1
,
*
(
B
+
1
));
d9
=
MLA
(
d9
,
d2
,
*
(
B
+
2
));
d9
=
MLA
(
d9
,
d3
,
*
(
B
+
3
));
d9
=
MLA
(
d9
,
d4
,
*
(
B
+
4
));
d9
=
MLA
(
d9
,
d5
,
*
(
B
+
5
));
d9
=
MLA
(
d9
,
d6
,
*
(
B
+
6
));
d9
=
MLA
(
d9
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d10
=
MLA
(
d10
,
d0
,
*
(
B
));
d10
=
MLA
(
d10
,
d1
,
*
(
B
+
1
));
d10
=
MLA
(
d10
,
d2
,
*
(
B
+
2
));
d10
=
MLA
(
d10
,
d3
,
*
(
B
+
3
));
d10
=
MLA
(
d10
,
d4
,
*
(
B
+
4
));
d10
=
MLA
(
d10
,
d5
,
*
(
B
+
5
));
d10
=
MLA
(
d10
,
d6
,
*
(
B
+
6
));
d10
=
MLA
(
d10
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d11
=
MLA
(
d11
,
d0
,
*
(
B
));
d11
=
MLA
(
d11
,
d1
,
*
(
B
+
1
));
d11
=
MLA
(
d11
,
d2
,
*
(
B
+
2
));
d11
=
MLA
(
d11
,
d3
,
*
(
B
+
3
));
d11
=
MLA
(
d11
,
d4
,
*
(
B
+
4
));
d11
=
MLA
(
d11
,
d5
,
*
(
B
+
5
));
d11
=
MLA
(
d11
,
d6
,
*
(
B
+
6
));
d11
=
MLA
(
d11
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d12
=
MLA
(
d12
,
d0
,
*
(
B
));
d12
=
MLA
(
d12
,
d1
,
*
(
B
+
1
));
d12
=
MLA
(
d12
,
d2
,
*
(
B
+
2
));
d12
=
MLA
(
d12
,
d3
,
*
(
B
+
3
));
d12
=
MLA
(
d12
,
d4
,
*
(
B
+
4
));
d12
=
MLA
(
d12
,
d5
,
*
(
B
+
5
));
d12
=
MLA
(
d12
,
d6
,
*
(
B
+
6
));
d12
=
MLA
(
d12
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d13
=
MLA
(
d13
,
d0
,
*
(
B
));
d13
=
MLA
(
d13
,
d1
,
*
(
B
+
1
));
d13
=
MLA
(
d13
,
d2
,
*
(
B
+
2
));
d13
=
MLA
(
d13
,
d3
,
*
(
B
+
3
));
d13
=
MLA
(
d13
,
d4
,
*
(
B
+
4
));
d13
=
MLA
(
d13
,
d5
,
*
(
B
+
5
));
d13
=
MLA
(
d13
,
d6
,
*
(
B
+
6
));
d13
=
MLA
(
d13
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d14
=
MLA
(
d14
,
d0
,
*
(
B
));
d14
=
MLA
(
d14
,
d1
,
*
(
B
+
1
));
d14
=
MLA
(
d14
,
d2
,
*
(
B
+
2
));
d14
=
MLA
(
d14
,
d3
,
*
(
B
+
3
));
d14
=
MLA
(
d14
,
d4
,
*
(
B
+
4
));
d14
=
MLA
(
d14
,
d5
,
*
(
B
+
5
));
d14
=
MLA
(
d14
,
d6
,
*
(
B
+
6
));
d14
=
MLA
(
d14
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
;
d15
=
MLA
(
d15
,
d0
,
*
(
B
));
d15
=
MLA
(
d15
,
d1
,
*
(
B
+
1
));
d15
=
MLA
(
d15
,
d2
,
*
(
B
+
2
));
d15
=
MLA
(
d15
,
d3
,
*
(
B
+
3
));
d15
=
MLA
(
d15
,
d4
,
*
(
B
+
4
));
d15
=
MLA
(
d15
,
d5
,
*
(
B
+
5
));
d15
=
MLA
(
d15
,
d6
,
*
(
B
+
6
));
d15
=
MLA
(
d15
,
d7
,
*
(
B
+
7
));
B
=
B
+
8
+
LDB
;
}
GiStoreFloat16
(
C
,
d8
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d9
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d10
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d11
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d12
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d13
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d14
);
C
=
C
+
8
;
GiStoreFloat16
(
C
,
d15
);
C
=
C
+
8
;
}
#undef MLA
}
// namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK
(
gi_sgemm_nopack_mk8_8x8_fp16
);
void
gi_sgemm_nopack_mk8_8x8_fp16
::
kern
(
const
dt_float16
*
A
,
size_t
LDA
,
const
dt_float16
*
B
,
size_t
LDB
,
dt_float16
*
C
,
size_t
LDC
,
size_t
M
,
size_t
K
,
size_t
N
,
const
dt_float16
*
,
void
*
,
bool
trA
,
bool
trB
)
const
{
constexpr
size_t
MB
=
8
;
constexpr
size_t
KB
=
8
;
constexpr
size_t
NB
=
8
;
constexpr
size_t
NB_HALF
=
4
;
megdnn_assert
(
!
trA
&&
!
trB
&&
M
%
MB
==
0
&&
K
%
KB
==
0
);
for
(
size_t
m
=
0
;
m
<
M
;
m
+=
MB
)
{
gi_float16_t
*
output
=
reinterpret_cast
<
gi_float16_t
*>
(
C
)
+
(
m
/
MB
)
*
LDC
;
const
gi_float16_t
*
cur_B
=
reinterpret_cast
<
const
gi_float16_t
*>
(
B
);
size_t
n
=
0
;
for
(;
n
+
NB
-
1
<
N
;
n
+=
NB
)
{
kern_8x8
(
reinterpret_cast
<
const
gi_float16_t
*>
(
A
),
cur_B
,
LDB
,
K
,
output
);
cur_B
+=
KB
*
NB
;
output
+=
MB
*
NB
;
}
if
(
N
-
n
>=
4
)
{
kern_8x4
(
reinterpret_cast
<
const
gi_float16_t
*>
(
A
),
cur_B
,
LDB
,
K
,
output
);
cur_B
+=
KB
*
NB_HALF
;
output
+=
MB
*
NB_HALF
;
n
+=
4
;
}
while
(
n
<
N
)
{
kern_8x1
(
reinterpret_cast
<
const
gi_float16_t
*>
(
A
),
cur_B
,
LDB
,
K
,
output
);
cur_B
+=
KB
;
output
+=
MB
;
n
++
;
}
A
+=
LDA
;
}
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/fallback/matrix_mul/opr_impl.cpp
浏览文件 @
a450d0f5
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "src/common/algo_chooser.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_common.h"
#include "src/fallback/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
...
@@ -30,6 +31,9 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
...
@@ -30,6 +31,9 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32GiMK4_4x8
f32_mk4_4x8
;
AlgoF32GiMK4_4x8
f32_mk4_4x8
;
AlgoF32GiMK4Pack4x12
f32_mk4_gi_pack_4x12
;
AlgoF32GiMK4Pack4x12
f32_mk4_gi_pack_4x12
;
AlgoF32Gi4x12
f32_4x8
;
AlgoF32Gi4x12
f32_4x8
;
#if defined(GI_SUPPORT_F16)
AlgoF16GiMK8_8x8
f16_mk8_8x8
;
#endif
SmallVector
<
AlgoBase
*>
m_all_algos
;
SmallVector
<
AlgoBase
*>
m_all_algos
;
AlgoBase
::
Mapper
m_all_algos_map
;
AlgoBase
::
Mapper
m_all_algos_map
;
...
@@ -39,6 +43,9 @@ public:
...
@@ -39,6 +43,9 @@ public:
m_all_algos
.
emplace_back
(
&
f32_mk4_4x8
);
m_all_algos
.
emplace_back
(
&
f32_mk4_4x8
);
m_all_algos
.
emplace_back
(
&
f32_mk4_gi_pack_4x12
);
m_all_algos
.
emplace_back
(
&
f32_mk4_gi_pack_4x12
);
m_all_algos
.
emplace_back
(
&
f32_4x8
);
m_all_algos
.
emplace_back
(
&
f32_4x8
);
#if defined(GI_SUPPORT_F16)
m_all_algos
.
emplace_back
(
&
f16_mk8_8x8
);
#endif
m_all_algos
.
emplace_back
(
&
gemv
);
m_all_algos
.
emplace_back
(
&
gemv
);
m_all_algos
.
emplace_back
(
&
f32_k8x12x1
);
m_all_algos
.
emplace_back
(
&
f32_k8x12x1
);
m_all_algos
.
emplace_back
(
&
naive
);
m_all_algos
.
emplace_back
(
&
naive
);
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
a450d0f5
...
@@ -103,6 +103,7 @@ public:
...
@@ -103,6 +103,7 @@ public:
FB_NAIVE
,
FB_NAIVE
,
FB_GI_F32_GEMV_MK4
,
FB_GI_F32_GEMV_MK4
,
FB_GI_F32_MK4_4x8
,
FB_GI_F32_MK4_4x8
,
FB_GI_F16_MK8_8x8
,
FB_GI_F32_MK4_PACK_4x12
,
FB_GI_F32_MK4_PACK_4x12
,
FB_GI_F32_4x12
,
FB_GI_F32_4x12
,
...
@@ -237,6 +238,7 @@ private:
...
@@ -237,6 +238,7 @@ private:
class
AlgoF32GiMK4_4x8
;
// fallback F32 gi Gemm NCHW44
class
AlgoF32GiMK4_4x8
;
// fallback F32 gi Gemm NCHW44
class
AlgoF32GiMK4Pack4x12
;
// fallback F32 gi Gemm pack NCHW44
class
AlgoF32GiMK4Pack4x12
;
// fallback F32 gi Gemm pack NCHW44
class
AlgoF32Gi4x12
;
// fallback F32 gi Gemm
class
AlgoF32Gi4x12
;
// fallback F32 gi Gemm
class
AlgoF16GiMK8_8x8
;
class
AlgoGemv
;
class
AlgoGemv
;
class
AlgoNaive
;
class
AlgoNaive
;
class
AlgoPack
;
class
AlgoPack
;
...
...
dnn/test/fallback/matrix_mul.cpp
浏览文件 @
a450d0f5
#include "test/common/matrix_mul.h"
#include "test/common/matrix_mul.h"
#include "src/fallback/general_intrinsic/gi_common.h"
#include "test/common/checker.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/rng.h"
#include "test/common/task_record_check.h"
#include "test/common/task_record_check.h"
...
@@ -42,6 +43,14 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
...
@@ -42,6 +43,14 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
"FB_GI_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
,
1
);
"FB_GI_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
,
1
);
}
}
#if defined(GI_SUPPORT_F16)
TEST_F
(
FALLBACK
,
MATRIX_MUL_FP16_MK8_GI
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
handle
(),
"FB_GI_F16_MK8_8x8"
,
param
::
MatrixMul
::
Format
::
MK8
,
1
);
}
#endif
TEST_F
(
FALLBACK
,
MATRIX_MUL_GI_F32_4x12
)
{
TEST_F
(
FALLBACK
,
MATRIX_MUL_GI_F32_4x12
)
{
matrix_mul
::
check_matrix_mul
(
matrix_mul
::
check_matrix_mul
(
dtype
::
Float32
{},
dtype
::
Float32
{},
dtype
::
Float32
{},
handle
(),
dtype
::
Float32
{},
dtype
::
Float32
{},
dtype
::
Float32
{},
handle
(),
...
@@ -183,6 +192,15 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) {
...
@@ -183,6 +192,15 @@ TEST_F(FALLBACK, BENCHMARK_MATRIX_FB_GI_F32_MK4_4x8) {
"FB_GI_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
);
"FB_GI_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
);
}
}
#if defined(GI_SUPPORT_F16)
TEST_F
(
FALLBACK
,
BENCHMARK_MATRIX_FB_GI_F16_MK8_8x8
)
{
auto
args
=
matrix_mul
::
get_benchmark_matmul_args
();
matrix_mul
::
benchmark_single_algo
(
handle
(),
args
,
dtype
::
Float16
{},
dtype
::
Float16
{},
dtype
::
Float16
{},
"FB_GI_F16_MK8_8x8"
,
param
::
MatrixMul
::
Format
::
MK8
);
}
#endif
#endif
#endif
}
// namespace test
}
// namespace test
}
// namespace megdnn
}
// namespace megdnn
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录