Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0612b573
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0612b573
编写于
2月 15, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(dnn/arm): optimize gevm by reducing access to memory of matrix A
GitOrigin-RevId: 89ed7bfd50114be4fc8c2c3283bf92883afc5283
上级
b622064a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
224 additions
and
339 deletions
+224
-339
dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp
dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp
+224
-339
未找到文件。
dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp
浏览文件 @
0612b573
...
...
@@ -85,6 +85,18 @@ void hgemv_naive_n(
}
}
// namespace
// Multiply-accumulate helpers taking a scalar multiplier:
//   VFMAQ_N_F16(a, b, n) -> a + b * n on 128-bit vectors (vfmaq/vaddq/vmulq forms)
//   VFMA_N_F16(a, b, n)  -> a + b * n on 64-bit vectors  (vfma/vadd/vmul forms)
// `n` is a scalar applied to every lane of `b`, so callers do not need to
// broadcast it into a vector first.
//
// When __aarch64__ is defined the scalar-operand FMA intrinsic is used directly.
// Otherwise the operation is spelled as a multiply-by-scalar followed by a
// separate add.
// NOTE(review): presumably the *_n_f16 FMA intrinsics are not available on
// 32-bit ARM -- confirm against the ACLE intrinsics reference. The fallback is
// an unfused multiply + add, so its rounding may differ from the aarch64 path.
#if defined(__aarch64__)
#define VFMAQ_N_F16(a, b, n) vfmaq_n_f16(a, b, n)
#else
#define VFMAQ_N_F16(a, b, n) vaddq_f16(a, vmulq_n_f16(b, n))
#endif
#if defined(__aarch64__)
#define VFMA_N_F16(a, b, n) vfma_n_f16(a, b, n)
#else
#define VFMA_N_F16(a, b, n) vadd_f16(a, vmul_n_f16(b, n))
#endif
void
megdnn
::
arm_common
::
gemv_like
(
const
__fp16
*
__restrict
A
,
const
__fp16
*
__restrict
B
,
__fp16
*
__restrict
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
Astride
,
size_t
Bstride
,
size_t
Cstride
)
{
...
...
@@ -98,33 +110,30 @@ void megdnn::arm_common::gemv_like(
memset
(
C
+
m
*
Cstride
,
0
,
4
*
sizeof
(
__fp16
)
*
N
);
for
(;
k
+
4
<=
K
;
k
+=
4
)
{
size_t
n
=
0
;
__fp16
a00
=
A
[
m
*
Astride
+
k
],
a01
=
A
[
m
*
Astride
+
k
+
1
],
a02
=
A
[
m
*
Astride
+
k
+
2
],
a03
=
A
[
m
*
Astride
+
k
+
3
];
__fp16
a10
=
A
[(
m
+
1
)
*
Astride
+
k
],
a11
=
A
[(
m
+
1
)
*
Astride
+
k
+
1
],
a12
=
A
[(
m
+
1
)
*
Astride
+
k
+
2
],
a13
=
A
[(
m
+
1
)
*
Astride
+
k
+
3
];
__fp16
a20
=
A
[(
m
+
2
)
*
Astride
+
k
],
a21
=
A
[(
m
+
2
)
*
Astride
+
k
+
1
],
a22
=
A
[(
m
+
2
)
*
Astride
+
k
+
2
],
a23
=
A
[(
m
+
2
)
*
Astride
+
k
+
3
];
__fp16
a30
=
A
[(
m
+
3
)
*
Astride
+
k
],
a31
=
A
[(
m
+
3
)
*
Astride
+
k
+
1
],
a32
=
A
[(
m
+
3
)
*
Astride
+
k
+
2
],
a33
=
A
[(
m
+
3
)
*
Astride
+
k
+
3
];
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a01
,
a02
,
a03
,
a10
,
a11
,
a12
,
a13
,
a20
,
a21
,
a22
,
a23
,
a30
,
a31
,
a32
,
a33
;
float16x8_t
b0
,
b1
,
b2
,
b3
;
float16x8_t
c0
,
c1
,
c2
,
c3
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
4
)
UNROLL_OUT
(
loadB
,
4
)
UNROLL_OUT
(
loadA0
,
4
)
UNROLL_OUT
(
loadA1
,
4
)
UNROLL_OUT
(
loadA2
,
4
)
UNROLL_OUT
(
loadA3
,
4
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
UNROLL_OUT
(
calculate_row0
,
4
)
UNROLL_OUT
(
calculate_row1
,
4
)
UNROLL_OUT
(
calculate_row2
,
4
)
...
...
@@ -138,32 +147,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a01
,
a02
,
a03
,
a10
,
a11
,
a12
,
a13
,
a20
,
a21
,
a22
,
a23
,
a30
,
a31
,
a32
,
a33
;
float16x4_t
b0
,
b1
,
b2
,
b3
;
float16x4_t
c0
,
c1
,
c2
,
c3
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
4
)
UNROLL_OUT
(
loadB
,
4
)
UNROLL_OUT
(
loadA0
,
4
)
UNROLL_OUT
(
loadA1
,
4
)
UNROLL_OUT
(
loadA2
,
4
)
UNROLL_OUT
(
loadA3
,
4
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
UNROLL_OUT
(
calculate_row0
,
4
)
UNROLL_OUT
(
calculate_row1
,
4
)
UNROLL_OUT
(
calculate_row2
,
4
)
...
...
@@ -177,8 +172,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a01
,
a02
,
a03
,
a10
,
a11
,
a12
,
a13
,
a20
,
a21
,
a22
,
a23
,
a30
,
a31
,
a32
,
a33
;
__fp16
b0
,
b1
,
b2
,
b3
;
__fp16
c0
,
c1
,
c2
,
c3
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -187,18 +180,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
4
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
4
)
UNROLL_OUT
(
loadA1
,
4
)
UNROLL_OUT
(
loadA2
,
4
)
UNROLL_OUT
(
loadA3
,
4
)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0
+=
a00
*
b0
+
a01
*
b1
+
a02
*
b2
+
a03
*
b3
;
c1
+=
a10
*
b0
+
a11
*
b1
+
a12
*
b2
+
a13
*
b3
;
c2
+=
a20
*
b0
+
a21
*
b1
+
a22
*
b2
+
a23
*
b3
;
...
...
@@ -210,32 +191,23 @@ void megdnn::arm_common::gemv_like(
}
for
(;
k
+
2
<=
K
;
k
+=
2
)
{
size_t
n
=
0
;
__fp16
a00
=
A
[
m
*
Astride
+
k
],
a01
=
A
[
m
*
Astride
+
k
+
1
];
__fp16
a10
=
A
[(
m
+
1
)
*
Astride
+
k
],
a11
=
A
[(
m
+
1
)
*
Astride
+
k
+
1
];
__fp16
a20
=
A
[(
m
+
2
)
*
Astride
+
k
],
a21
=
A
[(
m
+
2
)
*
Astride
+
k
+
1
];
__fp16
a30
=
A
[(
m
+
3
)
*
Astride
+
k
],
a31
=
A
[(
m
+
3
)
*
Astride
+
k
+
1
];
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a01
,
a10
,
a11
,
a20
,
a21
,
a30
,
a31
;
float16x8_t
b0
,
b1
;
float16x8_t
c0
,
c1
,
c2
,
c3
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
4
)
UNROLL_OUT
(
loadB
,
2
)
UNROLL_OUT
(
loadA0
,
2
)
UNROLL_OUT
(
loadA1
,
2
)
UNROLL_OUT
(
loadA2
,
2
)
UNROLL_OUT
(
loadA3
,
2
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
UNROLL_OUT
(
calculate_row0
,
2
)
UNROLL_OUT
(
calculate_row1
,
2
)
UNROLL_OUT
(
calculate_row2
,
2
)
...
...
@@ -249,31 +221,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a01
,
a10
,
a11
,
a20
,
a21
,
a30
,
a31
;
float16x4_t
b0
,
b1
;
float16x4_t
c0
,
c1
,
c2
,
c3
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
4
)
UNROLL_OUT
(
loadB
,
2
)
UNROLL_OUT
(
loadA0
,
2
)
UNROLL_OUT
(
loadA1
,
2
)
UNROLL_OUT
(
loadA2
,
2
)
UNROLL_OUT
(
loadA3
,
2
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
UNROLL_OUT
(
calculate_row0
,
2
)
UNROLL_OUT
(
calculate_row1
,
2
)
UNROLL_OUT
(
calculate_row2
,
2
)
...
...
@@ -287,7 +246,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a01
,
a10
,
a11
,
a20
,
a21
,
a30
,
a31
;
__fp16
b0
,
b1
;
__fp16
c0
,
c1
,
c2
,
c3
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -296,18 +254,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
2
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
2
)
UNROLL_OUT
(
loadA1
,
2
)
UNROLL_OUT
(
loadA2
,
2
)
UNROLL_OUT
(
loadA3
,
2
)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0
+=
a00
*
b0
+
a01
*
b1
;
c1
+=
a10
*
b0
+
a11
*
b1
;
c2
+=
a20
*
b0
+
a21
*
b1
;
...
...
@@ -319,32 +265,23 @@ void megdnn::arm_common::gemv_like(
}
for
(;
k
<
K
;
k
+=
1
)
{
size_t
n
=
0
;
__fp16
a00
=
A
[
m
*
Astride
+
k
];
__fp16
a10
=
A
[(
m
+
1
)
*
Astride
+
k
];
__fp16
a20
=
A
[(
m
+
2
)
*
Astride
+
k
];
__fp16
a30
=
A
[(
m
+
3
)
*
Astride
+
k
];
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a10
,
a20
,
a30
;
float16x8_t
b0
;
float16x8_t
c0
,
c1
,
c2
,
c3
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdupq_n_f16(A[(m + 3) * Astride + k + i]);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
4
)
UNROLL_OUT
(
loadB
,
1
)
UNROLL_OUT
(
loadA0
,
1
)
UNROLL_OUT
(
loadA1
,
1
)
UNROLL_OUT
(
loadA2
,
1
)
UNROLL_OUT
(
loadA3
,
1
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vmlaq_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vmlaq_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMAQ_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMAQ_N_F16(c3, b##i, a3##i);
UNROLL_OUT
(
calculate_row0
,
1
)
UNROLL_OUT
(
calculate_row1
,
1
)
UNROLL_OUT
(
calculate_row2
,
1
)
...
...
@@ -358,31 +295,18 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a10
,
a20
,
a30
;
float16x4_t
b0
;
float16x4_t
c0
,
c1
,
c2
,
c3
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]);
#define loadA3(i) a3##i = vdup_n_f16(A[(m + 3) * Astride + k + i]);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
4
)
UNROLL_OUT
(
loadB
,
1
)
UNROLL_OUT
(
loadA0
,
1
)
UNROLL_OUT
(
loadA1
,
1
)
UNROLL_OUT
(
loadA2
,
1
)
UNROLL_OUT
(
loadA3
,
1
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = vfma_f16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = vfma_f16(c3, b##i, a3##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
#define calculate_row2(i) c2 = VFMA_N_F16(c2, b##i, a2##i);
#define calculate_row3(i) c3 = VFMA_N_F16(c3, b##i, a3##i);
UNROLL_OUT
(
calculate_row0
,
1
)
UNROLL_OUT
(
calculate_row1
,
1
)
UNROLL_OUT
(
calculate_row2
,
1
)
...
...
@@ -396,7 +320,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a10
,
a20
,
a30
;
__fp16
b0
;
__fp16
c0
,
c1
,
c2
,
c3
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -405,18 +328,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
1
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
#define loadA2(i) a2##i = A[(m + 2) * Astride + k + i];
#define loadA3(i) a3##i = A[(m + 3) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
1
)
UNROLL_OUT
(
loadA1
,
1
)
UNROLL_OUT
(
loadA2
,
1
)
UNROLL_OUT
(
loadA3
,
1
)
#undef loadA0
#undef loadA1
#undef loadA2
#undef loadA3
c0
=
c0
+
a00
*
b0
;
c1
=
c1
+
a10
*
b0
;
c2
=
c2
+
a20
*
b0
;
...
...
@@ -432,24 +343,22 @@ void megdnn::arm_common::gemv_like(
memset
(
C
+
m
*
Cstride
,
0
,
2
*
sizeof
(
__fp16
)
*
N
);
for
(;
k
+
4
<=
K
;
k
+=
4
)
{
size_t
n
=
0
;
__fp16
a00
=
A
[
m
*
Astride
+
k
],
a01
=
A
[
m
*
Astride
+
k
+
1
],
a02
=
A
[
m
*
Astride
+
k
+
2
],
a03
=
A
[
m
*
Astride
+
k
+
3
];
__fp16
a10
=
A
[(
m
+
1
)
*
Astride
+
k
],
a11
=
A
[(
m
+
1
)
*
Astride
+
k
+
1
],
a12
=
A
[(
m
+
1
)
*
Astride
+
k
+
2
],
a13
=
A
[(
m
+
1
)
*
Astride
+
k
+
3
];
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a01
,
a02
,
a03
,
a10
,
a11
,
a12
,
a13
;
float16x8_t
b0
,
b1
,
b2
,
b3
;
float16x8_t
c0
,
c1
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
2
)
UNROLL_OUT
(
loadB
,
4
)
UNROLL_OUT
(
loadA0
,
4
)
UNROLL_OUT
(
loadA1
,
4
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
UNROLL_OUT
(
calculate_row0
,
4
)
UNROLL_OUT
(
calculate_row1
,
4
)
#undef calculate_row0
...
...
@@ -459,23 +368,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a01
,
a02
,
a03
,
a10
,
a11
,
a12
,
a13
;
float16x4_t
b0
,
b1
,
b2
,
b3
;
float16x4_t
c0
,
c1
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
2
)
UNROLL_OUT
(
loadB
,
4
)
UNROLL_OUT
(
loadA0
,
4
)
UNROLL_OUT
(
loadA1
,
4
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
UNROLL_OUT
(
calculate_row0
,
4
)
UNROLL_OUT
(
calculate_row1
,
4
)
#undef calculate_row0
...
...
@@ -485,7 +387,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a01
,
a02
,
a03
,
a10
,
a11
,
a12
,
a13
;
__fp16
b0
,
b1
,
b2
,
b3
;
__fp16
c0
,
c1
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -494,12 +395,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
4
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
4
)
UNROLL_OUT
(
loadA1
,
4
)
#undef loadA0
#undef loadA1
c0
+=
a00
*
b0
+
a01
*
b1
+
a02
*
b2
+
a03
*
b3
;
c1
+=
a10
*
b0
+
a11
*
b1
+
a12
*
b2
+
a13
*
b3
;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
...
...
@@ -509,24 +404,19 @@ void megdnn::arm_common::gemv_like(
}
for
(;
k
+
2
<=
K
;
k
+=
2
)
{
size_t
n
=
0
;
__fp16
a00
=
A
[
m
*
Astride
+
k
],
a01
=
A
[
m
*
Astride
+
k
+
1
];
__fp16
a10
=
A
[(
m
+
1
)
*
Astride
+
k
],
a11
=
A
[(
m
+
1
)
*
Astride
+
k
+
1
];
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a01
,
a10
,
a11
;
float16x8_t
b0
,
b1
;
float16x8_t
c0
,
c1
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
2
)
UNROLL_OUT
(
loadB
,
2
)
UNROLL_OUT
(
loadA0
,
2
)
UNROLL_OUT
(
loadA1
,
2
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
UNROLL_OUT
(
calculate_row0
,
2
)
UNROLL_OUT
(
calculate_row1
,
2
)
#undef calculate_row0
...
...
@@ -536,23 +426,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a01
,
a10
,
a11
;
float16x4_t
b0
,
b1
;
float16x4_t
c0
,
c1
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
2
)
UNROLL_OUT
(
loadB
,
2
)
UNROLL_OUT
(
loadA0
,
2
)
UNROLL_OUT
(
loadA1
,
2
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
UNROLL_OUT
(
calculate_row0
,
2
)
UNROLL_OUT
(
calculate_row1
,
2
)
#undef calculate_row0
...
...
@@ -562,7 +445,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a01
,
a10
,
a11
;
__fp16
b0
,
b1
;
__fp16
c0
,
c1
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -571,12 +453,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
2
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
2
)
UNROLL_OUT
(
loadA1
,
2
)
#undef loadA0
#undef loadA1
c0
+=
a00
*
b0
+
a01
*
b1
;
c1
+=
a10
*
b0
+
a11
*
b1
;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
...
...
@@ -586,24 +462,19 @@ void megdnn::arm_common::gemv_like(
}
for
(;
k
<
K
;
k
+=
1
)
{
size_t
n
=
0
;
__fp16
a00
=
A
[
m
*
Astride
+
k
];
__fp16
a10
=
A
[(
m
+
1
)
*
Astride
+
k
];
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a10
;
float16x8_t
b0
;
float16x8_t
c0
,
c1
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]);
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
2
)
UNROLL_OUT
(
loadB
,
1
)
UNROLL_OUT
(
loadA0
,
1
)
UNROLL_OUT
(
loadA1
,
1
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vmlaq_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMAQ_N_F16(c1, b##i, a1##i);
UNROLL_OUT
(
calculate_row0
,
1
)
UNROLL_OUT
(
calculate_row1
,
1
)
#undef calculate_row0
...
...
@@ -613,23 +484,16 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a10
;
float16x4_t
b0
;
float16x4_t
c0
,
c1
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
#define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]);
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
2
)
UNROLL_OUT
(
loadB
,
1
)
UNROLL_OUT
(
loadA0
,
1
)
UNROLL_OUT
(
loadA1
,
1
)
#undef loadB
#undef loadC
#undef loadA0
#undef loadA1
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = vfma_f16(c1, b##i, a1##i);
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#define calculate_row1(i) c1 = VFMA_N_F16(c1, b##i, a1##i);
UNROLL_OUT
(
calculate_row0
,
1
)
UNROLL_OUT
(
calculate_row1
,
1
)
#undef calculate_row0
...
...
@@ -639,7 +503,6 @@ void megdnn::arm_common::gemv_like(
#undef vstore
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a10
;
__fp16
b0
;
__fp16
c0
,
c1
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -648,12 +511,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
1
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
#define loadA1(i) a1##i = A[(m + 1) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
1
)
UNROLL_OUT
(
loadA1
,
1
)
#undef loadA0
#undef loadA1
c0
=
c0
+
a00
*
b0
;
c1
=
c1
+
a10
*
b0
;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
...
...
@@ -667,48 +524,61 @@ void megdnn::arm_common::gemv_like(
memset
(
C
+
m
*
Cstride
,
0
,
sizeof
(
__fp16
)
*
N
);
for
(;
k
+
4
<=
K
;
k
+=
4
)
{
size_t
n
=
0
;
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a01
,
a02
,
a03
;
float16x8_t
b0
,
b1
,
b2
,
b3
;
float16x8_t
c0
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
4
)
UNROLL_OUT
(
loadA0
,
4
)
__fp16
a00
=
A
[
m
*
Astride
+
k
],
a01
=
A
[
m
*
Astride
+
k
+
1
],
a02
=
A
[
m
*
Astride
+
k
+
2
],
a03
=
A
[
m
*
Astride
+
k
+
3
];
{
#if !defined(__aarch64__)
float16x8_t
va00
=
vdupq_n_f16
(
a00
),
va01
=
vdupq_n_f16
(
a01
),
va02
=
vdupq_n_f16
(
a02
),
va03
=
vdupq_n_f16
(
a03
);
#endif
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
b0
,
b1
,
b2
,
b3
;
float16x8_t
c0
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
4
)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
UNROLL_OUT
(
calculate_row0
,
4
)
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT
(
calculate_row0
,
4
)
#undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT
(
vstore
,
1
)
UNROLL_OUT
(
vstore
,
1
)
#undef vstore
}
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a01
,
a02
,
a03
;
float16x4_t
b0
,
b1
,
b2
,
b3
;
float16x4_t
c0
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
4
)
UNROLL_OUT
(
loadA0
,
4
)
{
#if !defined(__aarch64__)
float16x4_t
va00
=
vdup_n_f16
(
a00
),
va01
=
vdup_n_f16
(
a01
),
va02
=
vdup_n_f16
(
a02
),
va03
=
vdup_n_f16
(
a03
);
#endif
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
b0
,
b1
,
b2
,
b3
;
float16x4_t
c0
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
4
)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
UNROLL_OUT
(
calculate_row0
,
4
)
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT
(
calculate_row0
,
4
)
#undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT
(
vstore
,
1
)
UNROLL_OUT
(
vstore
,
1
)
#undef vstore
}
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a01
,
a02
,
a03
;
__fp16
b0
,
b1
,
b2
,
b3
;
__fp16
c0
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -717,9 +587,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
4
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[m * Astride + k + i];
UNROLL_OUT
(
loadA0
,
4
)
#undef loadA0
c0
+=
a00
*
b0
+
a01
*
b1
+
a02
*
b2
+
a03
*
b3
;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT
(
vstore
,
1
)
...
...
@@ -727,49 +594,59 @@ void megdnn::arm_common::gemv_like(
}
}
for
(;
k
+
2
<=
K
;
k
+=
2
)
{
__fp16
a00
=
A
[
m
*
Astride
+
k
],
a01
=
A
[
m
*
Astride
+
k
+
1
];
size_t
n
=
0
;
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
,
a01
;
float16x8_t
b0
,
b1
;
float16x8_t
c0
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
2
)
UNROLL_OUT
(
loadA0
,
2
)
{
#if !defined(__aarch64__)
float16x8_t
va00
=
vdupq_n_f16
(
a00
),
va01
=
vdupq_n_f16
(
a01
);
#endif
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
b0
,
b1
;
float16x8_t
c0
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
2
)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
UNROLL_OUT
(
calculate_row0
,
2
)
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT
(
calculate_row0
,
2
)
#undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT
(
vstore
,
1
)
UNROLL_OUT
(
vstore
,
1
)
#undef vstore
}
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
,
a01
;
float16x4_t
b0
,
b1
;
float16x4_t
c0
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
2
)
UNROLL_OUT
(
loadA0
,
2
)
{
#if !defined(__aarch64__)
float16x4_t
va00
=
vdup_n_f16
(
a00
),
va01
=
vdup_n_f16
(
a01
);
#endif
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
b0
,
b1
;
float16x4_t
c0
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
2
)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
UNROLL_OUT
(
calculate_row0
,
2
)
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT
(
calculate_row0
,
2
)
#undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT
(
vstore
,
1
)
UNROLL_OUT
(
vstore
,
1
)
#undef vstore
}
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
,
a01
;
__fp16
b0
,
b1
;
__fp16
c0
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -778,9 +655,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
2
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
2
)
#undef loadA0
c0
+=
a00
*
b0
+
a01
*
b1
;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT
(
vstore
,
1
)
...
...
@@ -789,48 +663,58 @@ void megdnn::arm_common::gemv_like(
}
for
(;
k
<
K
;
k
+=
1
)
{
size_t
n
=
0
;
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
a00
;
float16x8_t
b0
;
float16x8_t
c0
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
1
)
UNROLL_OUT
(
loadA0
,
1
)
__fp16
a00
=
A
[
m
*
Astride
+
k
];
{
#if !defined(__aarch64__)
float16x8_t
va00
=
vdupq_n_f16
(
a00
);
#endif
for
(;
n
+
8
<=
N
;
n
+=
8
)
{
float16x8_t
b0
;
float16x8_t
c0
;
#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
1
)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vmlaq_f16(c0, b##i, a0##i);
UNROLL_OUT
(
calculate_row0
,
1
)
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMAQ_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfmaq_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT
(
calculate_row0
,
1
)
#undef calculate_row0
#define vstore(i) vst1q_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT
(
vstore
,
1
)
UNROLL_OUT
(
vstore
,
1
)
#undef vstore
}
}
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
a00
;
float16x4_t
b0
;
float16x4_t
c0
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
#define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
1
)
UNROLL_OUT
(
loadA0
,
1
)
{
#if !defined(__aarch64__)
float16x4_t
va00
=
vdup_n_f16
(
a00
);
#endif
for
(;
n
+
4
<=
N
;
n
+=
4
)
{
float16x4_t
b0
;
float16x4_t
c0
;
#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n);
#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n);
UNROLL_OUT
(
loadC
,
1
)
UNROLL_OUT
(
loadB
,
1
)
#undef loadB
#undef loadC
#undef loadA0
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, a0##i);
UNROLL_OUT
(
calculate_row0
,
1
)
#if defined(__aarch64__)
#define calculate_row0(i) c0 = VFMA_N_F16(c0, b##i, a0##i);
#else
#define calculate_row0(i) c0 = vfma_f16(c0, b##i, va0##i);
#endif
UNROLL_OUT
(
calculate_row0
,
1
)
#undef calculate_row0
#define vstore(i) vst1_f16(C + (m + i) * Cstride + n, c##i);
UNROLL_OUT
(
vstore
,
1
)
UNROLL_OUT
(
vstore
,
1
)
#undef vstore
}
}
for
(;
n
<
N
;
n
+=
1
)
{
__fp16
a00
;
__fp16
b0
;
__fp16
c0
;
#define loadC(i) c##i = C[(m + i) * Cstride + n];
...
...
@@ -839,9 +723,6 @@ void megdnn::arm_common::gemv_like(
UNROLL_OUT
(
loadB
,
1
)
#undef loadB
#undef loadC
#define loadA0(i) a0##i = A[(m + 0) * Astride + k + i];
UNROLL_OUT
(
loadA0
,
1
)
#undef loadA0
c0
=
c0
+
a00
*
b0
;
#define vstore(i) C[(m + i) * Cstride + n] = c##i;
UNROLL_OUT
(
vstore
,
1
)
...
...
@@ -850,6 +731,10 @@ void megdnn::arm_common::gemv_like(
}
}
}
// Drop the scalar-operand multiply-accumulate helper macros once the kernel
// above is finished, so they do not leak into the rest of the translation unit.
#undef VFMA_N_F16
#undef VFMAQ_N_F16
bool
megdnn
::
arm_common
::
is_hgemv_preferred
(
bool
transposeA
,
bool
transposeB
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
/*LDA*/
,
size_t
LDB
,
size_t
/*LDC*/
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录