Commit 9939c334
Authored May 31, 2018 by 吴承辉 (Wu Chenghui)

Merge branch 'gemm' into 'master'

Optimize gemm x84 (v8/v7) gemv v7

See merge request !544

Parents: 4aa83602, 24fade6d
Showing 3 changed files with 851 additions and 506 deletions (+851 / -506):

  mace/kernels/gemm.cc        +789  -481
  mace/kernels/gemm.h           +1    -0
  mace/kernels/gemm_test.cc    +61   -25
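For reference, the public entry points touched by this merge are the batched kernels::Gemm and kernels::Gemv. The following is a minimal usage sketch modeled on the calls in gemm_test.cc; the sizes and the random fill are illustrative only, and it assumes index_t is the integer type made visible through mace/kernels/gemm.h.

#include <algorithm>
#include <memory>
#include <random>

#include "mace/kernels/gemm.h"

// Multiply `batch` independent N x K and K x M matrix pairs into C,
// mirroring the call pattern used by the new GemmTest helper.
void RunBatchedGemm() {
  const mace::index_t batch = 3, N = 17, K = 63, M = 127;
  std::unique_ptr<float[]> A(new float[batch * N * K]);
  std::unique_ptr<float[]> B(new float[batch * K * M]);
  std::unique_ptr<float[]> C(new float[batch * N * M]);
  std::mt19937 gen(0);
  std::normal_distribution<float> nd(0, 1);
  std::generate(A.get(), A.get() + batch * N * K, [&] { return nd(gen); });
  std::generate(B.get(), B.get() + batch * K * M, [&] { return nd(gen); });
  mace::kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
}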
mace/kernels/gemm.cc
 @@ -20,11 +20,13 @@
 #include <arm_neon.h>
 #endif

+#include "mace/core/macros.h"
 #include "mace/kernels/gemm.h"
 #include "mace/utils/utils.h"
 #include "mace/utils/logging.h"

+#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
+#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#endif

 namespace mace {
 namespace kernels {
 @@ -47,13 +49,36 @@ inline void GemmBlock(const float *A,
     }
   }
 }

 // TODO(liyin): may need implement 883 since RGB
+#if defined(MACE_ENABLE_NEON)
+#if defined(__aarch64__)
+#define MACE_GEMM_PART_CAL(RC, RA, RAN)              \
+  c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0);      \
+  c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1);      \
+  c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2);      \
+  c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3);      \
+  c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0);     \
+  c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1);     \
+  c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2);     \
+  c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
+#else
+#define MACE_GEMM_PART_CAL(RC, RA, RAN)                        \
+  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0);   \
+  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1);   \
+  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0);  \
+  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1);  \
+  c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0);  \
+  c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1);  \
+  c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
+  c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);
+#endif
+#endif

 inline void Gemm884(const float *a_ptr,
                     const float *b_ptr,
                     index_t stride_k,
                     index_t stride_w,
                     float *c_ptr) {
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+#if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
       a15;
   float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
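Each MACE_GEMM_PART_CAL(RC, RA, RAN) expansion accumulates one 4-wide output row of the register block: the eight A values held in the register pair a##RA/a##RAN are broadcast lane by lane against the eight loaded B rows b0..b7. A scalar sketch of that update follows, written only to make the lane arithmetic explicit; it is my paraphrase, not code from the diff.

#include <cstddef>

// Scalar equivalent of one MACE_GEMM_PART_CAL(RC, RA, RAN) expansion:
// c[0..3] += sum over j of a[j] * b[j][0..3], where a[0..7] are the eight
// A-values packed in the register pair (a##RA, a##RAN) and b[j] is the
// j-th loaded row of B (four floats wide).
void GemmPartCalScalar(const float a[8], const float b[8][4], float c[4]) {
  for (std::size_t j = 0; j < 8; ++j) {
    for (std::size_t lane = 0; lane < 4; ++lane) {
      c[lane] += a[j] * b[j][lane];
    }
  }
}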
 @@ -94,24 +119,25 @@ inline void Gemm884(const float *a_ptr,
   c6 = vld1q_f32(c_ptr + 6 * stride_w);
   c7 = vld1q_f32(c_ptr + 7 * stride_w);

-#define MACE_CONV_1x1_REG_CAL(RC, RA, RAN)        \
-  c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0);   \
-  c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1);   \
-  c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2);   \
-  c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3);   \
-  c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2);  \
-  c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
-
-  MACE_CONV_1x1_REG_CAL(0, 0, 1);
-  MACE_CONV_1x1_REG_CAL(1, 2, 3);
-  MACE_CONV_1x1_REG_CAL(2, 4, 5);
-  MACE_CONV_1x1_REG_CAL(3, 6, 7);
-  MACE_CONV_1x1_REG_CAL(4, 8, 9);
-  MACE_CONV_1x1_REG_CAL(5, 10, 11);
-  MACE_CONV_1x1_REG_CAL(6, 12, 13);
-  MACE_CONV_1x1_REG_CAL(7, 14, 15);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+  MACE_GEMM_PART_CAL(6, 12, 13);
+  MACE_GEMM_PART_CAL(7, 14, 15);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+  MACE_GEMM_PART_CAL(6, 12, 13);
+  MACE_GEMM_PART_CAL(7, 14, 15);
+#endif

   vst1q_f32(c_ptr, c0);
   vst1q_f32(c_ptr + 1 * stride_w, c1);
 @@ -121,12 +147,428 @@ inline void Gemm884(const float *a_ptr,
   vst1q_f32(c_ptr + 5 * stride_w, c5);
   vst1q_f32(c_ptr + 6 * stride_w, c6);
   vst1q_f32(c_ptr + 7 * stride_w, c7);
 #else
   GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
 #endif
 }

+inline void Gemm184(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+  MACE_UNUSED(stride_k);
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+#endif
+  vst1q_f32(c_ptr, c0);
+#else
+  GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void Gemm284(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1, a2, a3;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0, c1;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_k);
+  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+#endif
+  vst1q_f32(c_ptr, c0);
+  vst1q_f32(c_ptr + 1 * stride_w, c1);
+#else
+  GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void Gemm384(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1, a2, a3, a4, a5;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0, c1, c2;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_k);
+  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_k);
+  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+  c2 = vld1q_f32(c_ptr + 2 * stride_w);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+#endif
+  vst1q_f32(c_ptr, c0);
+  vst1q_f32(c_ptr + 1 * stride_w, c1);
+  vst1q_f32(c_ptr + 2 * stride_w, c2);
+#else
+  GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void Gemm484(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0, c1, c2, c3;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_k);
+  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_k);
+  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_k);
+  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+  c2 = vld1q_f32(c_ptr + 2 * stride_w);
+  c3 = vld1q_f32(c_ptr + 3 * stride_w);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+#endif
+  vst1q_f32(c_ptr, c0);
+  vst1q_f32(c_ptr + 1 * stride_w, c1);
+  vst1q_f32(c_ptr + 2 * stride_w, c2);
+  vst1q_f32(c_ptr + 3 * stride_w, c3);
+#else
+  GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void Gemm584(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0, c1, c2, c3, c4;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_k);
+  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_k);
+  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_k);
+  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_k);
+  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+  c2 = vld1q_f32(c_ptr + 2 * stride_w);
+  c3 = vld1q_f32(c_ptr + 3 * stride_w);
+  c4 = vld1q_f32(c_ptr + 4 * stride_w);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+#endif
+  vst1q_f32(c_ptr, c0);
+  vst1q_f32(c_ptr + 1 * stride_w, c1);
+  vst1q_f32(c_ptr + 2 * stride_w, c2);
+  vst1q_f32(c_ptr + 3 * stride_w, c3);
+  vst1q_f32(c_ptr + 4 * stride_w, c4);
+#else
+  GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void Gemm684(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0, c1, c2, c3, c4, c5;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_k);
+  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_k);
+  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_k);
+  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_k);
+  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
+  a10 = vld1q_f32(a_ptr + 5 * stride_k);
+  a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+  c2 = vld1q_f32(c_ptr + 2 * stride_w);
+  c3 = vld1q_f32(c_ptr + 3 * stride_w);
+  c4 = vld1q_f32(c_ptr + 4 * stride_w);
+  c5 = vld1q_f32(c_ptr + 5 * stride_w);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+#endif
+  vst1q_f32(c_ptr, c0);
+  vst1q_f32(c_ptr + 1 * stride_w, c1);
+  vst1q_f32(c_ptr + 2 * stride_w, c2);
+  vst1q_f32(c_ptr + 3 * stride_w, c3);
+  vst1q_f32(c_ptr + 4 * stride_w, c4);
+  vst1q_f32(c_ptr + 5 * stride_w, c5);
+#else
+  GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void Gemm784(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr) {
+#if defined(MACE_ENABLE_NEON)
+  float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13;
+  float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+  float32x4_t c0, c1, c2, c3, c4, c5, c6;
+  a0 = vld1q_f32(a_ptr);
+  a1 = vld1q_f32(a_ptr + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_k);
+  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_k);
+  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_k);
+  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_k);
+  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
+  a10 = vld1q_f32(a_ptr + 5 * stride_k);
+  a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
+  a12 = vld1q_f32(a_ptr + 6 * stride_k);
+  a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
+  b0 = vld1q_f32(b_ptr);
+  b1 = vld1q_f32(b_ptr + 1 * stride_w);
+  b2 = vld1q_f32(b_ptr + 2 * stride_w);
+  b3 = vld1q_f32(b_ptr + 3 * stride_w);
+  b4 = vld1q_f32(b_ptr + 4 * stride_w);
+  b5 = vld1q_f32(b_ptr + 5 * stride_w);
+  b6 = vld1q_f32(b_ptr + 6 * stride_w);
+  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  c0 = vld1q_f32(c_ptr);
+  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+  c2 = vld1q_f32(c_ptr + 2 * stride_w);
+  c3 = vld1q_f32(c_ptr + 3 * stride_w);
+  c4 = vld1q_f32(c_ptr + 4 * stride_w);
+  c5 = vld1q_f32(c_ptr + 5 * stride_w);
+  c6 = vld1q_f32(c_ptr + 6 * stride_w);
+#if defined(__aarch64__)
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+  MACE_GEMM_PART_CAL(6, 12, 13);
+#else
+  MACE_GEMM_PART_CAL(0, 0, 1);
+  MACE_GEMM_PART_CAL(1, 2, 3);
+  MACE_GEMM_PART_CAL(2, 4, 5);
+  MACE_GEMM_PART_CAL(3, 6, 7);
+  MACE_GEMM_PART_CAL(4, 8, 9);
+  MACE_GEMM_PART_CAL(5, 10, 11);
+  MACE_GEMM_PART_CAL(6, 12, 13);
+#endif
+  vst1q_f32(c_ptr, c0);
+  vst1q_f32(c_ptr + 1 * stride_w, c1);
+  vst1q_f32(c_ptr + 2 * stride_w, c2);
+  vst1q_f32(c_ptr + 3 * stride_w, c3);
+  vst1q_f32(c_ptr + 4 * stride_w, c4);
+  vst1q_f32(c_ptr + 5 * stride_w, c5);
+  vst1q_f32(c_ptr + 6 * stride_w, c6);
+#else
+  GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_k, stride_w, c_ptr);
+#endif
+}
+
+inline void GemmX84(const float *a_ptr, const float *b_ptr,
+                    index_t stride_k, index_t stride_w, float *c_ptr,
+                    int row) {
+  switch (row) {
+    case 1: Gemm184(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 2: Gemm284(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 3: Gemm384(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 4: Gemm484(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 5: Gemm584(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 6: Gemm684(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 7: Gemm784(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    case 8: Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); break;
+    default:
+      MACE_NOT_IMPLEMENTED;
+  }
+}

 inline void GemmTile(const float *A,
                      const float *B,
                      const index_t height,
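All of the Gemm184..Gemm884 kernels added above compute the same shaped update that GemmX84 dispatches on: `row` rows of A, 8 deep, against an 8 x 4 panel of B, accumulated into a row x 4 panel of C. A plain reference with the same pointer/stride convention is sketched below; it is equivalent in effect to the GemmBlock(a_ptr, b_ptr, row, 8, 4, ...) fallback, and the function name and use of ptrdiff_t are mine.

#include <cstddef>

// Reference for the row x 8 x 4 block the NEON kernels compute:
// C[h][w] += sum_k A[h][k] * B[k][w], with rows of A stride_k apart and
// rows of B and C stride_w apart.
void GemmX84Reference(const float *a_ptr, const float *b_ptr,
                      std::ptrdiff_t stride_k, std::ptrdiff_t stride_w,
                      float *c_ptr, int row) {
  for (int h = 0; h < row; ++h) {
    for (int k = 0; k < 8; ++k) {
      for (int w = 0; w < 4; ++w) {
        c_ptr[h * stride_w + w] +=
            a_ptr[h * stride_k + k] * b_ptr[k * stride_w + w];
      }
    }
  }
}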
 @@ -137,18 +579,15 @@ inline void GemmTile(const float *A,
                      float *C) {
 #if defined(MACE_ENABLE_NEON)
   index_t h, w, k;
-#endif
-
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
   for (h = 0; h < height - 7; h += 8) {
     for (k = 0; k < K - 7; k += 8) {
       const float *a_ptr = A + (h * stride_k + k);
-#ifdef __clang__
+#if defined(__aarch64__) && defined(__clang__)
       int nw = width >> 2;
       if (nw > 0) {
         // load A
         float32x4_t a0, a1, a2, a3, a4, a5, a6, a7,
             a8, a9, a10, a11, a12, a13, a14, a15;
         a0 = vld1q_f32(a_ptr);
         a1 = vld1q_f32(a_ptr + 4);
         a2 = vld1q_f32(a_ptr + 1 * stride_k);
 @@ -378,30 +817,19 @@ inline void GemmTile(const float *A,
             "w"(a11),  // 47
             "w"(a13),  // 48
             "w"(a15)   // 49
           : "cc", "memory",
-            "v16", "v17", "v18", "v19", "v20", "v21", "v22",
-            "v23", "v24", "v25");
+            "v16",
+            "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
         w = (width >> 2) << 2;
       }
-#else  // gcc
+#else  // gcc || armv7a
       for (w = 0; w + 3 < width; w += 4) {
         const float *b_ptr = B + (k * stride_w + w);
         float *c_ptr = C + (h * stride_w + w);
         Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
       }
-#endif  // clang
+#endif  // clang && armv8a
       if (w < width) {
-        const float *a_ptr = A + (h * stride_k + k);
         const float *b_ptr = B + (k * stride_w + w);
         float *c_ptr = C + (h * stride_w + w);
         GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
 @@ -411,154 +839,37 @@ inline void GemmTile(const float *A,
       const float *a_ptr = A + (h * stride_k + k);
       const float *b_ptr = B + k * stride_w;
       float *c_ptr = C + h * stride_w;
       GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr);
     }
   }
   if (h < height) {
-    // TODO(liyin): may use Gemm444
-    const float *a_ptr = A + (h * stride_k);
-    const float *b_ptr = B;
-    float *c_ptr = C + h * stride_w;
-    GemmBlock(a_ptr, b_ptr, height - h, K, width, stride_k, stride_w, c_ptr);
-  }
-#else
-#if defined(MACE_ENABLE_NEON)  // armv7
-  w = (width >> 2) << 2;
-  for (h = 0; h + 3 < height; h += 4) {
-    for (k = 0; k + 3 < K; k += 4) {
-      const float *a_ptr = A + (h * stride_k + k);
-      // load A
-      float32x2_t a00, a01, a10, a11, a20, a21, a30, a31;
-      a00 = vld1_f32(a_ptr);
-      a01 = vld1_f32(a_ptr + 2);
-      a10 = vld1_f32(a_ptr + 1 * stride_k);
-      a11 = vld1_f32(a_ptr + 1 * stride_k + 2);
-      a20 = vld1_f32(a_ptr + 2 * stride_k);
-      a21 = vld1_f32(a_ptr + 2 * stride_k + 2);
-      a30 = vld1_f32(a_ptr + 3 * stride_k);
-      a31 = vld1_f32(a_ptr + 3 * stride_k + 2);
-      const float *b_ptr0 = B + k * stride_w;
-      const float *b_ptr1 = B + (k + 1) * stride_w;
-      const float *b_ptr2 = B + (k + 2) * stride_w;
-      const float *b_ptr3 = B + (k + 3) * stride_w;
-      float *c_ptr0 = C + h * stride_w;
-      float *c_ptr1 = C + (h + 1) * stride_w;
-      float *c_ptr2 = C + (h + 2) * stride_w;
-      float *c_ptr3 = C + (h + 3) * stride_w;
-      // TODO(liyin): asm v7 prefetch and load optimization
-      int nw = width >> 2;
-      while (nw--) {
-        float32x4_t b0, b1, b2, b3;
-        float32x4_t c0, c1, c2, c3;
-        c0 = vld1q_f32(c_ptr0);
-        b0 = vld1q_f32(b_ptr0);
-        b1 = vld1q_f32(b_ptr1);
-        b2 = vld1q_f32(b_ptr2);
-        b3 = vld1q_f32(b_ptr3);
-        c1 = vld1q_f32(c_ptr1);
-        c2 = vld1q_f32(c_ptr2);
-        c3 = vld1q_f32(c_ptr3);
-        c0 = vmlaq_lane_f32(c0, b0, a00, 0);
-        c0 = vmlaq_lane_f32(c0, b1, a00, 1);
-        c0 = vmlaq_lane_f32(c0, b2, a01, 0);
-        c0 = vmlaq_lane_f32(c0, b3, a01, 1);
-        vst1q_f32(c_ptr0, c0);
-        c1 = vmlaq_lane_f32(c1, b0, a10, 0);
-        c1 = vmlaq_lane_f32(c1, b1, a10, 1);
-        c1 = vmlaq_lane_f32(c1, b2, a11, 0);
-        c1 = vmlaq_lane_f32(c1, b3, a11, 1);
-        vst1q_f32(c_ptr1, c1);
-        c2 = vmlaq_lane_f32(c2, b0, a20, 0);
-        c2 = vmlaq_lane_f32(c2, b1, a20, 1);
-        c2 = vmlaq_lane_f32(c2, b2, a21, 0);
-        c2 = vmlaq_lane_f32(c2, b3, a21, 1);
-        vst1q_f32(c_ptr2, c2);
-        c3 = vmlaq_lane_f32(c3, b0, a30, 0);
-        c3 = vmlaq_lane_f32(c3, b1, a30, 1);
-        c3 = vmlaq_lane_f32(c3, b2, a31, 0);
-        c3 = vmlaq_lane_f32(c3, b3, a31, 1);
-        vst1q_f32(c_ptr3, c3);
-        b_ptr0 += 4;
-        b_ptr1 += 4;
-        b_ptr2 += 4;
-        b_ptr3 += 4;
-        c_ptr0 += 4;
-        c_ptr1 += 4;
-        c_ptr2 += 4;
-        c_ptr3 += 4;
-      }
-      if (w < width) {
-        const float *b_ptr = B + (k * stride_w + w);
-        float *c_ptr = C + (h * stride_w + w);
-        GemmBlock(a_ptr, b_ptr, 4, 4, width - w, stride_k, stride_w, c_ptr);
-      }
-    }
-    if (k < K) {
-      const float *a_ptr = A + (h * stride_k + k);
-      const float *b_ptr = B + k * stride_w;
-      float *c_ptr = C + h * stride_w;
-      GemmBlock(a_ptr, b_ptr, 4, K - k, width, stride_k, stride_w, c_ptr);
-    }
-  }
-  if (h < height) {
-    const float *a_ptr = A + (h * stride_k);
-    const float *b_ptr = B;
-    float *c_ptr = C + h * stride_w;
-    GemmBlock(a_ptr, b_ptr, height - h, K, width, stride_k, stride_w, c_ptr);
-  }
-#else  // cpu
+    index_t remain_h = height - h;
+    for (k = 0; k < K - 7; k += 8) {
+      const float *a_ptr = A + (h * stride_k + k);
+      int nw = width >> 2;
+      index_t w;
+      if (nw > 0) {
+        for (w = 0; w + 3 < width; w += 4) {
+          const float *b_ptr = B + (k * stride_w + w);
+          float *c_ptr = C + (h * stride_w + w);
+          GemmX84(a_ptr, b_ptr, stride_k, stride_w, c_ptr, remain_h);
+        }
+      }
+      if (w < width) {
+        const float *b_ptr = B + (k * stride_w + w);
+        float *c_ptr = C + (h * stride_w + w);
+        GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_k, stride_w,
+                  c_ptr);
+      }
+    }
+    if (k < K) {
+      const float *a_ptr = A + (h * stride_k + k);
+      const float *b_ptr = B + k * stride_w;
+      float *c_ptr = C + h * stride_w;
+      GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_k, stride_w,
+                c_ptr);
+    }
+  }
+#else
   GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
-#endif  // armv7
-#endif  // aarch64
+#endif  // MACE_ENABLE_NEON
 }
 }  // namespace
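The GemmTile changes above amount to a new traversal: full 8-row, 8-deep tiles go to Gemm884 (or the aarch64 inline-assembly path), the K remainder goes to GemmBlock, and the final 1..7 leftover rows now go to GemmX84 instead of a separate armv7 loop. The control-flow skeleton is sketched below with the width loop folded into a generic callback; the names and the std::function signature are mine, not part of the diff.

#include <cstdint>
#include <functional>

// Traversal skeleton mirroring GemmTile's loop structure: full 8-row x 8-deep
// tiles first, then the K remainder, then the 1..7 leftover rows.
// `block(h, k, rows, depth)` stands in for Gemm884 / GemmX84 / GemmBlock.
void TileTraversal(std::int64_t height, std::int64_t K,
                   const std::function<void(std::int64_t, std::int64_t,
                                            std::int64_t, std::int64_t)> &block) {
  std::int64_t h = 0, k = 0;
  for (h = 0; h < height - 7; h += 8) {
    for (k = 0; k < K - 7; k += 8) {
      block(h, k, 8, 8);               // Gemm884 (or the aarch64 asm path)
    }
    if (k < K) block(h, k, 8, K - k);  // GemmBlock on the K remainder
  }
  if (h < height) {
    const std::int64_t remain_h = height - h;
    for (k = 0; k < K - 7; k += 8) {
      block(h, k, remain_h, 8);        // GemmX84 picks Gemm184..Gemm884 by row count
    }
    if (k < K) block(h, k, remain_h, K - k);  // GemmBlock on the tail
  }
}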
 @@ -602,29 +913,25 @@ void Gemm(const float *A,
       const index_t ih_begin = bh * block_size;
       const index_t ih_end =
           bh * block_size +
           (bh == block_tile_height - 1 && remain_height > 0 ? remain_height
                                                             : block_size);
       const index_t iw_begin = bw * block_size;
       const index_t iw_end =
           bw * block_size +
           (bw == block_tile_width - 1 && remain_width > 0 ? remain_width
                                                           : block_size);
       for (index_t bk = 0; bk < block_tile_k; ++bk) {
         const index_t ik_begin = bk * block_size;
         const index_t ik_end =
             bk * block_size +
             (bk == block_tile_k - 1 && remain_k > 0 ? remain_k : block_size);
         // inside block:
         // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
         GemmTile(a_base + (ih_begin * K + ik_begin),
                  b_base + (ik_begin * width + iw_begin),
                  ih_end - ih_begin,
                  ik_end - ik_begin, iw_end - iw_begin, K, width,
                  c_base + (ih_begin * width + iw_begin));
       }  // bk
     }  // bw
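The block-boundary arithmetic in this hunk is unchanged in substance: each of the bh/bw/bk indices covers [b * block_size, end), where the last block absorbs the remainder when the dimension does not divide evenly. As a small self-contained helper (the function is mine, written only to restate the ternary from the diff):

#include <cstdint>

// How Gemm() derives the end of block `b` along one dimension: every block is
// `block_size` long except the last, which takes the remainder when present.
std::int64_t BlockEnd(std::int64_t b, std::int64_t block_size,
                      std::int64_t num_blocks, std::int64_t remain) {
  return b * block_size +
         (b == num_blocks - 1 && remain > 0 ? remain : block_size);
}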
 @@ -635,59 +942,60 @@ void Gemm(const float *A,
 // A: height x K, B: K x width, C: height x width
 void GemmRef(const float *A,
              const float *B,
+             const index_t batch,
              const index_t height,
              const index_t K,
              const index_t width,
              float *C) {
-  memset(C, 0, sizeof(float) * height * width);
-  for (int i = 0; i < height; ++i) {
-    for (int j = 0; j < width; ++j) {
-      for (int k = 0; k < K; ++k) {
-        C[i * width + j] += A[i * K + k] * B[k * width + j];
+  memset(C, 0, sizeof(float) * batch * height * width);
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t i = 0; i < height; ++i) {
+      for (index_t j = 0; j < width; ++j) {
+        for (index_t k = 0; k < K; ++k) {
+          C[(b * height + i) * width + j] +=
+              A[(b * height + i) * K + k] * B[(b * K + k) * width + j];
+        }
       }
     }
   }
 }

 void GemvRef(const float *m_ptr,
              const float *v_ptr,
              const index_t batch,
              const index_t width,
              const index_t height,
              float *out_ptr) {
-  memset(out_ptr, 0, sizeof(float) * height * batch);
+  memset(out_ptr, 0, batch * height * sizeof(float));
+#pragma omp parallel for collapse(2)
   for (int b = 0; b < batch; ++b) {
     for (int h = 0; h < height; ++h) {
       for (int w = 0; w < width; ++w) {
-        out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w];
+        out_ptr[b * height + h] += v_ptr[b * width + w] * m_ptr[h * width + w];
       }
     }
   }
 }

 // M: height x width, Vin: width x 1, Vout: height x 1
+// TODO(liyin): batched gemv can be transformed to gemm (w/ transpose)
 void Gemv(const float *m_ptr,
           const float *v_ptr,
           const index_t batch,
           const index_t width,
           const index_t height,
           float *out_ptr) {
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
-  index_t height_d4 = height >> 2;
-  index_t width_d4 = width >> 2;
-  index_t remain_w = width - (width_d4 << 2);
-  index_t remain_h = height - (height_d4 << 2);
+#if defined(MACE_ENABLE_NEON)
+  // TODO(liyin/wch): try height tiling = 8
+#pragma omp parallel for collapse(2)
   for (index_t b = 0; b < batch; ++b) {
-#pragma omp parallel for
-    for (index_t h = 0; h < height_d4; ++h) {
-      const float *m_ptr0 = m_ptr + h * width * 4;
+    for (index_t h = 0; h < height; h += 4) {
+      if (h + 3 < height) {
+        const float *m_ptr0 = m_ptr + h * width;
         const float *m_ptr1 = m_ptr0 + width;
         const float *m_ptr2 = m_ptr1 + width;
         const float *m_ptr3 = m_ptr2 + width;
         const float *v_ptr0 = v_ptr + b * width;
-      float *out_ptr0 = out_ptr + h * 4 + b * height;
+        float *out_ptr0 = out_ptr + b * height + h;
         float32x4_t vm0, vm1, vm2, vm3;
         float32x4_t vv;
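The batched GemmRef above flattens every operand batch-major, so element (b, r, c) of an R x C matrix sits at offset (b * R + r) * C + c. The tiny index helper below restates that layout; it is illustrative only and not part of the diff.

#include <cstdint>

// Flattened offset used by the batched GemmRef: element (r, c) of the b-th
// `rows x cols` matrix lives at (b * rows + r) * cols + c.
inline std::int64_t Idx(std::int64_t b, std::int64_t rows, std::int64_t cols,
                        std::int64_t r, std::int64_t c) {
  return (b * rows + r) * cols + c;
}
// So the reference update reads:
//   C[Idx(b, height, width, i, j)] +=
//       A[Idx(b, height, K, i, k)] * B[Idx(b, K, width, k, j)];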
 @@ -697,7 +1005,8 @@ void Gemv(const float *m_ptr,
         float32x4_t vsum2 = vdupq_n_f32(0.f);
         float32x4_t vsum3 = vdupq_n_f32(0.f);
-      for (index_t w = 0; w < width_d4; ++w) {
+        index_t w;
+        for (w = 0; w + 3 < width; w += 4) {
           vm0 = vld1q_f32(m_ptr0);
           vm1 = vld1q_f32(m_ptr1);
           vm2 = vld1q_f32(m_ptr2);
 @@ -721,7 +1030,7 @@ void Gemv(const float *m_ptr,
         float sum3 = vaddvq_f32(vsum3);
         // handle remaining w
-      for (index_t w = 0; w < remain_w; ++w) {
+        for (; w < width; ++w) {
           sum0 += m_ptr0[0] * v_ptr0[0];
           sum1 += m_ptr1[0] * v_ptr0[0];
           sum2 += m_ptr2[0] * v_ptr0[0];
 @@ -736,16 +1045,13 @@ void Gemv(const float *m_ptr,
         *out_ptr0++ = sum1;
         *out_ptr0++ = sum2;
         *out_ptr0++ = sum3;
-    }
-    // handle remaining h
-    index_t remain_start_height = height_d4 << 2;
-#pragma omp parallel for
-    for (index_t h = 0; h < remain_h; ++h) {
-      float32x4_t vsum0 = vdupq_n_f32(0.f);
-      const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
-      const float *v_ptr0 = v_ptr + b * width;
-      for (index_t w = 0; w < width_d4; ++w) {
+      } else {
+        for (index_t hh = h; hh < height; ++hh) {
+          float32x4_t vsum0 = vdupq_n_f32(0.f);
+          const float *m_ptr0 = m_ptr + hh * width;
+          const float *v_ptr0 = v_ptr + b * width;
+          index_t w;
+          for (w = 0; w + 3 < width; w += 4) {
             float32x4_t vm = vld1q_f32(m_ptr0);
             float32x4_t vv = vld1q_f32(v_ptr0);
             vsum0 = vmlaq_f32(vsum0, vm, vv);
 @@ -753,14 +1059,16 @@ void Gemv(const float *m_ptr,
             v_ptr0 += 4;
           }
           float sum = vaddvq_f32(vsum0);
-      for (index_t w = 0; w < remain_w; ++w) {
+          for (; w < width; ++w) {
             sum += m_ptr0[0] * v_ptr0[0];
             m_ptr0++;
             v_ptr0++;
           }
-      out_ptr[remain_start_height + h + b * height] = sum;
-    }
-  }
+          out_ptr[b * height + hh] = sum;
+        }
+      }  // if
+    }  // h
+  }  // b
 #else
   GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
 #endif
 ...
mace/kernels/gemm.h

 @@ -34,6 +34,7 @@ void Gemm(const float *A,
 void GemmRef(const float *A,
              const float *B,
+             const index_t batch,
              const index_t height,
              const index_t K,
              const index_t width,
 ...
mace/kernels/gemm_test.cc

 @@ -21,62 +21,98 @@
 namespace mace {

-TEST(GEMMTest, gemm) {
-  index_t N = 17;
-  index_t M = 33;
-  index_t K = 64;
-  std::unique_ptr<float[]> A(new float[N * K]);
-  std::unique_ptr<float[]> B(new float[K * M]);
-  std::unique_ptr<float[]> C(new float[N * M]);
-  std::unique_ptr<float[]> C_ref(new float[N * M]);
+namespace {
+void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
+  std::unique_ptr<float[]> A(new float[batch * N * K]);
+  std::unique_ptr<float[]> B(new float[batch * K * M]);
+  std::unique_ptr<float[]> C(new float[batch * N * M]);
+  std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(A.get(), A.get() + N * K,
+  std::generate(A.get(), A.get() + batch * N * K,
                 [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + K * M,
+  std::generate(B.get(), B.get() + batch * K * M,
                 [&gen, &nd] { return nd(gen); });
-  kernels::Gemm(A.get(), B.get(), 1, N, K, M, C.get());
-  kernels::GemmRef(A.get(), B.get(), N, K, M, C_ref.get());
-  for (int i = 0; i < N * M; ++i) {
+  kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
+  kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
+  for (int i = 0; i < batch * N * M; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
   }
 }

-TEST(GEMMTest, gemv) {
-  index_t N = 17;
-  index_t K = 63;
-  std::unique_ptr<float[]> A(new float[N * K]);
-  std::unique_ptr<float[]> B(new float[K]);
-  std::unique_ptr<float[]> C(new float[N]);
-  std::unique_ptr<float[]> C_ref(new float[N]);
+void GemvTest(index_t batch, index_t N, index_t M) {
+  std::unique_ptr<float[]> A(new float[N * M]);
+  std::unique_ptr<float[]> B(new float[batch * M]);
+  std::unique_ptr<float[]> C(new float[batch * N]);
+  std::unique_ptr<float[]> C_ref(new float[batch * N]);
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(A.get(), A.get() + N * K,
+  std::generate(A.get(), A.get() + N * M,
                 [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + K,
+  std::generate(B.get(), B.get() + batch * M,
                 [&gen, &nd] { return nd(gen); });
-  kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
-  kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
-  for (int i = 0; i < N; ++i) {
+  kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
+  kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
+  for (int i = 0; i < batch * N; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
   }
 }
+}  // namespace

+TEST(GEMMTest, AlignedWithoutBatch) {
+  GemmTest(1, 1, 64, 128);
+  GemmTest(1, 2, 64, 128);
+  GemmTest(1, 3, 64, 128);
+  GemmTest(1, 4, 64, 128);
+  GemmTest(1, 5, 64, 128);
+  GemmTest(1, 6, 64, 128);
+  GemmTest(1, 7, 64, 128);
+  GemmTest(1, 17, 64, 128);
+}
+
+TEST(GEMMTest, UnalignedWithoutBatch) {
+  GemmTest(1, 1, 63, 127);
+  GemmTest(1, 2, 63, 127);
+  GemmTest(1, 3, 63, 127);
+  GemmTest(1, 4, 63, 127);
+  GemmTest(1, 5, 63, 127);
+  GemmTest(1, 6, 63, 127);
+  GemmTest(1, 7, 63, 127);
+  GemmTest(1, 17, 63, 127);
+}
+
+TEST(GEMMTest, UnalignedWithBatch) {
+  GemmTest(3, 1, 63, 127);
+  GemmTest(3, 2, 63, 127);
+  GemmTest(3, 3, 63, 127);
+  GemmTest(3, 4, 63, 127);
+  GemmTest(3, 5, 63, 127);
+  GemmTest(3, 6, 63, 127);
+  GemmTest(3, 7, 63, 127);
+  GemmTest(3, 17, 63, 127);
+}
+
+TEST(GEMMTest, gemv) {
+  GemvTest(1, 17, 63);
+  GemvTest(3, 17, 63);
+}

 }  // namespace mace