Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a54d9cb9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a54d9cb9
编写于
7月 11, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(x86/rvv): opt FB_GI_F32_MK4_PACK_4x12 algo
GitOrigin-RevId: a80805c119c2d572d9ea6447a3a32d0c2e2063fc
上级
d60d028a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
110 addition
and
117 deletion
+110
-117
dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp
dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp
+110
-117
未找到文件。
dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk_4x12.cpp
浏览文件 @
a54d9cb9
...
@@ -20,6 +20,9 @@ using namespace matmul::fallback;
...
@@ -20,6 +20,9 @@ using namespace matmul::fallback;
namespace
{
namespace
{
//! x86 and rvv GiSimdFmaLane API is slowly, as an alternate, use
//! GiMultiplyAddScalarFloat32
#define MLA GiMultiplyAddScalarFloat32
void
kern_4x12
(
void
kern_4x12
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
)
{
bool
is_first_k
)
{
...
@@ -32,24 +35,18 @@ void kern_4x12(
...
@@ -32,24 +35,18 @@ void kern_4x12(
K
=
((
K
+
1
)
/
2
)
-
1
;
K
=
((
K
+
1
)
/
2
)
-
1
;
float
*
r1
=
output
;
float
*
r1
=
output
;
GI_FLOAT32_t
d0d1
,
d2d3
,
d
4d5
,
d6d7
,
d8d9
,
d10d11
,
d12d13
,
d14d15
,
d16d17
,
d18d19
,
GI_FLOAT32_t
d0d1
,
d2d3
,
d
8d9
,
d10d11
,
d12d13
,
d14d15
,
d16d17
,
d18d19
,
d20d21
,
d2
0d21
,
d2
2d23
,
d24d25
,
d26d27
,
d28d29
,
d30d31
;
d22d23
,
d24d25
,
d26d27
,
d28d29
,
d30d31
;
if
(
is_first_k
)
{
if
(
is_first_k
)
{
d8d9
=
GiBroadcastFloat32
(
0.0
f
);
d8d9
=
GiBroadcastFloat32
(
0.0
f
);
d10d11
=
GiBroadcastFloat32
(
0.0
f
);
d10d11
=
GiBroadcastFloat32
(
0.0
f
);
d12d13
=
GiBroadcastFloat32
(
0.0
f
);
d12d13
=
GiBroadcastFloat32
(
0.0
f
);
d14d15
=
GiBroadcastFloat32
(
0.0
f
);
d14d15
=
GiBroadcastFloat32
(
0.0
f
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d16d17
=
GiBroadcastFloat32
(
0.0
f
);
d16d17
=
GiBroadcastFloat32
(
0.0
f
);
d18d19
=
GiBroadcastFloat32
(
0.0
f
);
d18d19
=
GiBroadcastFloat32
(
0.0
f
);
d20d21
=
GiBroadcastFloat32
(
0.0
f
);
d20d21
=
GiBroadcastFloat32
(
0.0
f
);
d22d23
=
GiBroadcastFloat32
(
0.0
f
);
d22d23
=
GiBroadcastFloat32
(
0.0
f
);
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d24d25
=
GiBroadcastFloat32
(
0.0
f
);
d24d25
=
GiBroadcastFloat32
(
0.0
f
);
d26d27
=
GiBroadcastFloat32
(
0.0
f
);
d26d27
=
GiBroadcastFloat32
(
0.0
f
);
d28d29
=
GiBroadcastFloat32
(
0.0
f
);
d28d29
=
GiBroadcastFloat32
(
0.0
f
);
...
@@ -84,145 +81,144 @@ void kern_4x12(
...
@@ -84,145 +81,144 @@ void kern_4x12(
r1
=
r1
+
4
;
r1
=
r1
+
4
;
d30d31
=
GiLoadFloat32
(
r1
);
d30d31
=
GiLoadFloat32
(
r1
);
r1
=
r1
+
4
;
r1
=
r1
+
4
;
}
for
(;
K
>
0
;
K
--
)
{
d0d1
=
GiLoadFloat32
(
a_ptr
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
a_ptr
=
a_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
d8d9
=
MLA
(
d8d9
,
d0d1
,
*
(
b_ptr
));
d10d11
=
MLA
(
d10d11
,
d0d1
,
*
(
b_ptr
+
1
));
d12d13
=
MLA
(
d12d13
,
d0d1
,
*
(
b_ptr
+
2
));
d14d15
=
MLA
(
d14d15
,
d0d1
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
}
for
(;
K
>
0
;
K
--
)
{
d16d17
=
MLA
(
d16d17
,
d0d1
,
*
(
b_ptr
));
d8d9
=
GiSimdFmaLane
(
d8d9
,
d0d1
,
d4d5
,
0
);
d18d19
=
MLA
(
d18d19
,
d0d1
,
*
(
b_ptr
+
1
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d0d1
,
d4d5
,
1
);
d20d21
=
MLA
(
d20d21
,
d0d1
,
*
(
b_ptr
+
2
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d0d1
,
d4d5
,
2
);
d22d23
=
MLA
(
d22d23
,
d0d1
,
*
(
b_ptr
+
3
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d0d1
,
d4d5
,
3
);
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d16d17
=
GiSimdFmaLane
(
d16d17
,
d0d1
,
d6d7
,
0
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d0d1
,
d6d7
,
1
);
d24d25
=
MLA
(
d24d25
,
d0d1
,
*
(
b_ptr
));
d20d21
=
GiSimdFmaLane
(
d20d21
,
d0d1
,
d6d7
,
2
);
d26d27
=
MLA
(
d26d27
,
d0d1
,
*
(
b_ptr
+
1
));
d28d29
=
MLA
(
d28d29
,
d0d1
,
*
(
b_ptr
+
2
));
d30d31
=
MLA
(
d30d31
,
d0d1
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
d2d3
=
GiLoadFloat32
(
a_ptr
);
d2d3
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
a_ptr
=
a_ptr
+
4
;
d22d23
=
GiSimdFmaLane
(
d22d23
,
d0d1
,
d6d7
,
3
);
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d24d25
=
GiSimdFmaLane
(
d24d25
,
d0d1
,
d4d5
,
0
);
d26d27
=
GiSimdFmaLane
(
d26d27
,
d0d1
,
d4d5
,
1
);
d28d29
=
GiSimdFmaLane
(
d28d29
,
d0d1
,
d4d5
,
2
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d0d1
,
d4d5
,
3
);
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d6d7
,
0
);
d8d9
=
MLA
(
d8d9
,
d2d3
,
*
(
b_ptr
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d2d3
,
d6d7
,
1
);
d10d11
=
MLA
(
d10d11
,
d2d3
,
*
(
b_ptr
+
1
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d2d3
,
d6d7
,
2
);
d12d13
=
MLA
(
d12d13
,
d2d3
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d6d7
,
3
);
d14d15
=
MLA
(
d14d15
,
d2d3
,
*
(
b_ptr
+
3
));
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d16d17
=
GiSimdFmaLane
(
d16d17
,
d2d3
,
d4d5
,
0
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d2d3
,
d4d5
,
1
);
d16d17
=
MLA
(
d16d17
,
d2d3
,
*
(
b_ptr
));
d0d1
=
GiLoadFloat32
(
a_ptr
);
d18d19
=
MLA
(
d18d19
,
d2d3
,
*
(
b_ptr
+
1
));
a_ptr
=
a_ptr
+
4
;
d20d21
=
MLA
(
d20d21
,
d2d3
,
*
(
b_ptr
+
2
));
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d4d5
,
2
);
d22d23
=
MLA
(
d22d23
,
d2d3
,
*
(
b_ptr
+
3
));
d22d23
=
GiSimdFmaLane
(
d22d23
,
d2d3
,
d4d5
,
3
);
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d24d25
=
GiSimdFmaLane
(
d24d25
,
d2d3
,
d6d7
,
0
);
d2
6d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d6d7
,
1
);
d2
4d25
=
MLA
(
d24d25
,
d2d3
,
*
(
b_ptr
)
);
d2
8d29
=
GiSimdFmaLane
(
d28d29
,
d2d3
,
d6d7
,
2
);
d2
6d27
=
MLA
(
d26d27
,
d2d3
,
*
(
b_ptr
+
1
)
);
d
30d31
=
GiSimdFmaLane
(
d30d31
,
d2d3
,
d6d7
,
3
);
d
28d29
=
MLA
(
d28d29
,
d2d3
,
*
(
b_ptr
+
2
)
);
d
6d7
=
GiLoadFloat32
(
b_ptr
);
d
30d31
=
MLA
(
d30d31
,
d2d3
,
*
(
b_ptr
+
3
)
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
}
}
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
if
(
1
==
oddk
)
{
if
(
1
==
oddk
)
{
d8d9
=
GiSimdFmaLane
(
d8d9
,
d0d1
,
d4d5
,
0
);
d8d9
=
MLA
(
d8d9
,
d0d1
,
*
(
b_ptr
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d0d1
,
d4d5
,
1
);
d10d11
=
MLA
(
d10d11
,
d0d1
,
*
(
b_ptr
+
1
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d0d1
,
d4d5
,
2
);
d12d13
=
MLA
(
d12d13
,
d0d1
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d0d1
,
d4d5
,
3
);
d14d15
=
MLA
(
d14d15
,
d0d1
,
*
(
b_ptr
+
3
));
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d16d17
=
GiSimdFmaLane
(
d16d17
,
d0d1
,
d6d7
,
0
);
d16d17
=
MLA
(
d16d17
,
d0d1
,
*
(
b_ptr
));
GiStoreFloat32
(
output0
,
d8d9
);
GiStoreFloat32
(
output0
,
d8d9
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d10d11
);
GiStoreFloat32
(
output0
,
d10d11
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d18d19
=
GiSimdFmaLane
(
d18d19
,
d0d1
,
d6d7
,
1
);
d18d19
=
MLA
(
d18d19
,
d0d1
,
*
(
b_ptr
+
1
)
);
d20d21
=
GiSimdFmaLane
(
d20d21
,
d0d1
,
d6d7
,
2
);
d20d21
=
MLA
(
d20d21
,
d0d1
,
*
(
b_ptr
+
2
)
);
GiStoreFloat32
(
output0
,
d12d13
);
GiStoreFloat32
(
output0
,
d12d13
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d14d15
);
GiStoreFloat32
(
output0
,
d14d15
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d22d23
=
GiSimdFmaLane
(
d22d23
,
d0d1
,
d6d7
,
3
);
d22d23
=
MLA
(
d22d23
,
d0d1
,
*
(
b_ptr
+
3
));
d24d25
=
GiSimdFmaLane
(
d24d25
,
d0d1
,
d4d5
,
0
);
b_ptr
=
b_ptr
+
4
;
d24d25
=
MLA
(
d24d25
,
d0d1
,
*
(
b_ptr
));
GiStoreFloat32
(
output0
,
d16d17
);
GiStoreFloat32
(
output0
,
d16d17
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d18d19
);
GiStoreFloat32
(
output0
,
d18d19
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d26d27
=
GiSimdFmaLane
(
d26d27
,
d0d1
,
d4d5
,
1
);
d26d27
=
MLA
(
d26d27
,
d0d1
,
*
(
b_ptr
+
1
)
);
GiStoreFloat32
(
output0
,
d20d21
);
GiStoreFloat32
(
output0
,
d20d21
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d22d23
);
GiStoreFloat32
(
output0
,
d22d23
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d28d29
=
GiSimdFmaLane
(
d28d29
,
d0d1
,
d4d5
,
2
);
d28d29
=
MLA
(
d28d29
,
d0d1
,
*
(
b_ptr
+
2
)
);
GiStoreFloat32
(
output0
,
d24d25
);
GiStoreFloat32
(
output0
,
d24d25
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d26d27
);
GiStoreFloat32
(
output0
,
d26d27
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d30d31
=
GiSimdFmaLane
(
d30d31
,
d0d1
,
d4d5
,
3
);
d30d31
=
MLA
(
d30d31
,
d0d1
,
*
(
b_ptr
+
3
)
);
GiStoreFloat32
(
output0
,
d28d29
);
GiStoreFloat32
(
output0
,
d28d29
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d30d31
);
GiStoreFloat32
(
output0
,
d30d31
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
b_ptr
=
b_ptr
+
4
;
}
else
{
}
else
{
d8d9
=
GiSimdFmaLane
(
d8d9
,
d0d1
,
d4d5
,
0
);
d8d9
=
MLA
(
d8d9
,
d0d1
,
*
(
b_ptr
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d0d1
,
d4d5
,
1
);
d10d11
=
MLA
(
d10d11
,
d0d1
,
*
(
b_ptr
+
1
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d0d1
,
d4d5
,
2
);
d12d13
=
MLA
(
d12d13
,
d0d1
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d0d1
,
d4d5
,
3
);
d14d15
=
MLA
(
d14d15
,
d0d1
,
*
(
b_ptr
+
3
));
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d16d17
=
GiSimdFmaLane
(
d16d17
,
d0d1
,
d6d7
,
0
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d0d1
,
d6d7
,
1
);
d16d17
=
MLA
(
d16d17
,
d0d1
,
*
(
b_ptr
));
d20d21
=
GiSimdFmaLane
(
d20d21
,
d0d1
,
d6d7
,
2
);
d18d19
=
MLA
(
d18d19
,
d0d1
,
*
(
b_ptr
+
1
));
d20d21
=
MLA
(
d20d21
,
d0d1
,
*
(
b_ptr
+
2
));
d2d3
=
GiLoadFloat32
(
a_ptr
);
d2d3
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
a_ptr
=
a_ptr
+
4
;
d22d23
=
GiSimdFmaLane
(
d22d23
,
d0d1
,
d6d7
,
3
);
d22d23
=
MLA
(
d22d23
,
d0d1
,
*
(
b_ptr
+
3
));
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d24d25
=
GiSimdFmaLane
(
d24d25
,
d0d1
,
d4d5
,
0
);
d2
6d27
=
GiSimdFmaLane
(
d26d27
,
d0d1
,
d4d5
,
1
);
d2
4d25
=
MLA
(
d24d25
,
d0d1
,
*
(
b_ptr
)
);
d2
8d29
=
GiSimdFmaLane
(
d28d29
,
d0d1
,
d4d5
,
2
);
d2
6d27
=
MLA
(
d26d27
,
d0d1
,
*
(
b_ptr
+
1
)
);
d
30d31
=
GiSimdFmaLane
(
d30d31
,
d0d1
,
d4d5
,
3
);
d
28d29
=
MLA
(
d28d29
,
d0d1
,
*
(
b_ptr
+
2
)
);
d
4d5
=
GiLoadFloat32
(
b_ptr
);
d
30d31
=
MLA
(
d30d31
,
d0d1
,
*
(
b_ptr
+
3
)
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d6d7
,
0
);
d8d9
=
MLA
(
d8d9
,
d2d3
,
*
(
b_ptr
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d2d3
,
d6d7
,
1
);
d10d11
=
MLA
(
d10d11
,
d2d3
,
*
(
b_ptr
+
1
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d2d3
,
d6d7
,
2
);
d12d13
=
MLA
(
d12d13
,
d2d3
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d6d7
,
3
);
d14d15
=
MLA
(
d14d15
,
d2d3
,
*
(
b_ptr
+
3
));
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d16d17
=
GiSimdFmaLane
(
d16d17
,
d2d3
,
d4d5
,
0
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d2d3
,
d4d5
,
1
);
d16d17
=
MLA
(
d16d17
,
d2d3
,
*
(
b_ptr
));
d18d19
=
MLA
(
d18d19
,
d2d3
,
*
(
b_ptr
+
1
));
GiStoreFloat32
(
output0
,
d8d9
);
GiStoreFloat32
(
output0
,
d8d9
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d10d11
);
GiStoreFloat32
(
output0
,
d10d11
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d4d5
,
2
);
d20d21
=
MLA
(
d20d21
,
d2d3
,
*
(
b_ptr
+
2
)
);
d22d23
=
GiSimdFmaLane
(
d22d23
,
d2d3
,
d4d5
,
3
);
d22d23
=
MLA
(
d22d23
,
d2d3
,
*
(
b_ptr
+
3
)
);
GiStoreFloat32
(
output0
,
d12d13
);
GiStoreFloat32
(
output0
,
d12d13
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d14d15
);
GiStoreFloat32
(
output0
,
d14d15
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d24d25
=
GiSimdFmaLane
(
d24d25
,
d2d3
,
d6d7
,
0
);
b_ptr
=
b_ptr
+
4
;
d26d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d6d7
,
1
);
d24d25
=
MLA
(
d24d25
,
d2d3
,
*
(
b_ptr
));
d26d27
=
MLA
(
d26d27
,
d2d3
,
*
(
b_ptr
+
1
));
GiStoreFloat32
(
output0
,
d16d17
);
GiStoreFloat32
(
output0
,
d16d17
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d18d19
);
GiStoreFloat32
(
output0
,
d18d19
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
d28d29
=
GiSimdFmaLane
(
d28d29
,
d2d3
,
d6d7
,
2
);
d28d29
=
MLA
(
d28d29
,
d2d3
,
*
(
b_ptr
+
2
)
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d2d3
,
d6d7
,
3
);
d30d31
=
MLA
(
d30d31
,
d2d3
,
*
(
b_ptr
+
3
)
);
GiStoreFloat32
(
output0
,
d20d21
);
GiStoreFloat32
(
output0
,
d20d21
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d22d23
);
GiStoreFloat32
(
output0
,
d22d23
);
...
@@ -235,6 +231,7 @@ void kern_4x12(
...
@@ -235,6 +231,7 @@ void kern_4x12(
output0
=
output0
+
4
;
output0
=
output0
+
4
;
GiStoreFloat32
(
output0
,
d30d31
);
GiStoreFloat32
(
output0
,
d30d31
);
output0
=
output0
+
4
;
output0
=
output0
+
4
;
b_ptr
=
b_ptr
+
4
;
}
}
}
}
...
@@ -249,7 +246,7 @@ void kern_4x4(
...
@@ -249,7 +246,7 @@ void kern_4x4(
K
=
((
K
+
1
)
/
2
)
-
1
;
K
=
((
K
+
1
)
/
2
)
-
1
;
float
*
r1
=
output
;
float
*
r1
=
output
;
GI_FLOAT32_t
d0d1
,
d2d3
,
d
4d5
,
d6d7
,
d
8d9
,
d10d11
,
d12d13
,
d14d15
;
GI_FLOAT32_t
d0d1
,
d2d3
,
d8d9
,
d10d11
,
d12d13
,
d14d15
;
if
(
is_first_k
)
{
if
(
is_first_k
)
{
d8d9
=
GiBroadcastFloat32
(
0.0
f
);
d8d9
=
GiBroadcastFloat32
(
0.0
f
);
...
@@ -260,9 +257,6 @@ void kern_4x4(
...
@@ -260,9 +257,6 @@ void kern_4x4(
d12d13
=
GiBroadcastFloat32
(
0.0
f
);
d12d13
=
GiBroadcastFloat32
(
0.0
f
);
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d14d15
=
GiBroadcastFloat32
(
0.0
f
);
d14d15
=
GiBroadcastFloat32
(
0.0
f
);
}
else
{
}
else
{
if
(
n_remain
==
4
)
{
if
(
n_remain
==
4
)
{
...
@@ -293,44 +287,43 @@ void kern_4x4(
...
@@ -293,44 +287,43 @@ void kern_4x4(
}
}
for
(;
K
>
0
;
K
--
)
{
for
(;
K
>
0
;
K
--
)
{
d8d9
=
GiSimdFmaLane
(
d8d9
,
d0d1
,
d4d5
,
0
);
d8d9
=
MLA
(
d8d9
,
d0d1
,
*
(
b_ptr
)
);
d2d3
=
GiLoadFloat32
(
a_ptr
);
d2d3
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
a_ptr
=
a_ptr
+
4
;
d10d11
=
GiSimdFmaLane
(
d10d11
,
d0d1
,
d4d5
,
1
);
d10d11
=
MLA
(
d10d11
,
d0d1
,
*
(
b_ptr
+
1
));
d6d7
=
GiLoadFloat32
(
b_ptr
);
d12d13
=
MLA
(
d12d13
,
d0d1
,
*
(
b_ptr
+
2
));
d14d15
=
MLA
(
d14d15
,
d0d1
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d12d13
=
GiSimdFmaLane
(
d12d13
,
d0d1
,
d4d5
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d0d1
,
d4d5
,
3
);
d4d5
=
GiLoadFloat32
(
b_ptr
);
d8d9
=
MLA
(
d8d9
,
d2d3
,
*
(
b_ptr
));
b_ptr
=
b_ptr
+
4
;
d10d11
=
MLA
(
d10d11
,
d2d3
,
*
(
b_ptr
+
1
));
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d6d7
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d2d3
,
d6d7
,
1
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
a_ptr
=
a_ptr
+
4
;
d12d13
=
GiSimdFmaLane
(
d12d13
,
d2d3
,
d6d7
,
2
);
d12d13
=
MLA
(
d12d13
,
d2d3
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d6d7
,
3
);
d14d15
=
MLA
(
d14d15
,
d2d3
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
}
}
if
(
1
==
oddk
)
{
if
(
1
==
oddk
)
{
d8d9
=
GiSimdFmaLane
(
d8d9
,
d0d1
,
d4d5
,
0
);
d8d9
=
MLA
(
d8d9
,
d0d1
,
*
(
b_ptr
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d0d1
,
d4d5
,
1
);
d10d11
=
MLA
(
d10d11
,
d0d1
,
*
(
b_ptr
+
1
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d0d1
,
d4d5
,
2
);
d12d13
=
MLA
(
d12d13
,
d0d1
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d0d1
,
d4d5
,
3
);
d14d15
=
MLA
(
d14d15
,
d0d1
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
}
else
{
}
else
{
d8d9
=
GiSimdFmaLane
(
d8d9
,
d0d1
,
d4d5
,
0
);
d8d9
=
MLA
(
d8d9
,
d0d1
,
*
(
b_ptr
)
);
d2d3
=
GiLoadFloat32
(
a_ptr
);
d2d3
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
a_ptr
=
a_ptr
+
4
;
d10d11
=
GiSimdFmaLane
(
d10d11
,
d0d1
,
d4d5
,
1
);
d10d11
=
MLA
(
d10d11
,
d0d1
,
*
(
b_ptr
+
1
));
d6d7
=
GiLoadFloat32
(
b_ptr
);
d12d13
=
MLA
(
d12d13
,
d0d1
,
*
(
b_ptr
+
2
));
d14d15
=
MLA
(
d14d15
,
d0d1
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
b_ptr
=
b_ptr
+
4
;
d12d13
=
GiSimdFmaLane
(
d12d13
,
d0d1
,
d4d5
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d0d1
,
d4d5
,
3
);
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d6d7
,
0
);
d8d9
=
MLA
(
d8d9
,
d2d3
,
*
(
b_ptr
));
d10d11
=
GiSimdFmaLane
(
d10d11
,
d2d3
,
d6d7
,
1
);
d10d11
=
MLA
(
d10d11
,
d2d3
,
*
(
b_ptr
+
1
));
d12d13
=
GiSimdFmaLane
(
d12d13
,
d2d3
,
d6d7
,
2
);
d12d13
=
MLA
(
d12d13
,
d2d3
,
*
(
b_ptr
+
2
));
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d6d7
,
3
);
d14d15
=
MLA
(
d14d15
,
d2d3
,
*
(
b_ptr
+
3
));
b_ptr
=
b_ptr
+
4
;
}
}
if
(
n_remain
==
4
)
{
if
(
n_remain
==
4
)
{
...
@@ -359,7 +352,7 @@ void kern_4x4(
...
@@ -359,7 +352,7 @@ void kern_4x4(
output
=
output
+
4
;
output
=
output
+
4
;
}
}
}
}
#undef MLA
}
// namespace
}
// namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gi_sgemm_mk4_pack_4x12
);
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gi_sgemm_mk4_pack_4x12
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录