Xiaomi / Mace
Commit a025ac02
Authored June 12, 2018 by 李寅

Optimize gemm v7

Parent: dad3d11a
Showing 3 changed files with 487 additions and 114 deletions
mace/kernels/gemm.cc             +474  −110
mace/ops/shape_test.cc             +3    −1
mace/ops/strided_slice_test.cc    +10    −3

mace/kernels/gemm.cc
@@ -50,7 +50,7 @@ inline void GemmBlock(const float *A,
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
-#define MACE_GEMM_PART_CAL(RC, RA, RAN)            \
+#define MACE_GEMM_PART_CAL_8(RC, RA, RAN)          \
  c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0);     \
  c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1);     \
  c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2);     \
@@ -60,7 +60,7 @@ inline void GemmBlock(const float *A,
  c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2);    \
  c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
#else
-#define MACE_GEMM_PART_CAL(RC, RA, RAN)                        \
+#define MACE_GEMM_PART_CAL_8(RC, RA, RAN)                      \
  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0);    \
  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1);    \
  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0);   \
@@ -72,6 +72,283 @@ inline void GemmBlock(const float *A,
#endif
#endif

#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL_4(RC)                   \
  c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0);    \
  c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1);    \
  c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2);    \
  c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3);
#else
#define MACE_GEMM_PART_CAL_4(RC)                               \
  c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0);   \
  c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1);   \
  c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0);  \
  c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);
#endif
#endif
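For orientation, a sketch of the effect of MACE_GEMM_PART_CAL_4(0) on the aarch64 path, written out as a standalone function (my illustration, not code from this commit): each of the four lanes of the A register a0 scales one 4-wide row of B and accumulates into the single output row c0.

#include <arm_neon.h>

// Illustration only (assumes aarch64 NEON): the accumulation pattern that
// MACE_GEMM_PART_CAL_4(0) expands to for output row 0.
static inline float32x4_t PartCal4Row0(float32x4_t c0, float32x4_t a0,
                                       float32x4_t b0, float32x4_t b1,
                                       float32x4_t b2, float32x4_t b3) {
  c0 = vfmaq_laneq_f32(c0, b0, a0, 0);  // c0 += b0 * a0[0]
  c0 = vfmaq_laneq_f32(c0, b1, a0, 1);  // c0 += b1 * a0[1]
  c0 = vfmaq_laneq_f32(c0, b2, a0, 2);  // c0 += b2 * a0[2]
  c0 = vfmaq_laneq_f32(c0, b3, a0, 3);  // c0 += b3 * a0[3]
  return c0;
}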
inline void Gemm144(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  MACE_UNUSED(stride_a);
  MACE_UNUSED(stride_c);
  float32x4_t a0;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0;

  a0 = vld1q_f32(a_ptr);
  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);
  c0 = vld1q_f32(c_ptr);

  MACE_GEMM_PART_CAL_4(0);

  vst1q_f32(c_ptr, c0);
#else
  GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
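Equivalently, in scalar form (my own reference sketch, not part of the commit), Gemm144 computes a 1×4 output row from a 1×4 strip of A and a 4×4 block of B, matching the GemmBlock(a_ptr, b_ptr, 1, 4, 4, ...) fallback above:

// Scalar reference for the 1x4x4 micro-kernel (illustration only).
inline void Gemm144Scalar(const float *a_ptr, const float *b_ptr,
                          const int stride_b, float *c_ptr) {
  for (int k = 0; k < 4; ++k) {
    for (int j = 0; j < 4; ++j) {
      c_ptr[j] += a_ptr[k] * b_ptr[k * stride_b + j];
    }
  }
}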
inline void Gemm244(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  float32x4_t a0, a1;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0, c1;

  a0 = vld1q_f32(a_ptr);
  a1 = vld1q_f32(a_ptr + 1 * stride_a);
  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);
  c0 = vld1q_f32(c_ptr);
  c1 = vld1q_f32(c_ptr + 1 * stride_c);

  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
#else
  GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm344(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  float32x4_t a0, a1, a2;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0, c1, c2;

  a0 = vld1q_f32(a_ptr);
  a1 = vld1q_f32(a_ptr + 1 * stride_a);
  a2 = vld1q_f32(a_ptr + 2 * stride_a);
  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);
  c0 = vld1q_f32(c_ptr);
  c1 = vld1q_f32(c_ptr + 1 * stride_c);
  c2 = vld1q_f32(c_ptr + 2 * stride_c);

  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
  GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm444(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  float32x4_t a0, a1, a2, a3;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0, c1, c2, c3;

  a0 = vld1q_f32(a_ptr);
  a1 = vld1q_f32(a_ptr + 1 * stride_a);
  a2 = vld1q_f32(a_ptr + 2 * stride_a);
  a3 = vld1q_f32(a_ptr + 3 * stride_a);
  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);
  c0 = vld1q_f32(c_ptr);
  c1 = vld1q_f32(c_ptr + 1 * stride_c);
  c2 = vld1q_f32(c_ptr + 2 * stride_c);
  c3 = vld1q_f32(c_ptr + 3 * stride_c);

  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);
  MACE_GEMM_PART_CAL_4(3);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
  GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm544(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  float32x4_t a0, a1, a2, a3, a4;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0, c1, c2, c3, c4;

  a0 = vld1q_f32(a_ptr);
  a1 = vld1q_f32(a_ptr + 1 * stride_a);
  a2 = vld1q_f32(a_ptr + 2 * stride_a);
  a3 = vld1q_f32(a_ptr + 3 * stride_a);
  a4 = vld1q_f32(a_ptr + 4 * stride_a);
  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);
  c0 = vld1q_f32(c_ptr);
  c1 = vld1q_f32(c_ptr + 1 * stride_c);
  c2 = vld1q_f32(c_ptr + 2 * stride_c);
  c3 = vld1q_f32(c_ptr + 3 * stride_c);
  c4 = vld1q_f32(c_ptr + 4 * stride_c);

  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);
  MACE_GEMM_PART_CAL_4(3);
  MACE_GEMM_PART_CAL_4(4);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
  vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
  GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm644(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  float32x4_t a0, a1, a2, a3, a4, a5;
  float32x4_t b0, b1, b2, b3;
  float32x4_t c0, c1, c2, c3, c4, c5;

  a0 = vld1q_f32(a_ptr);
  a1 = vld1q_f32(a_ptr + 1 * stride_a);
  a2 = vld1q_f32(a_ptr + 2 * stride_a);
  a3 = vld1q_f32(a_ptr + 3 * stride_a);
  a4 = vld1q_f32(a_ptr + 4 * stride_a);
  a5 = vld1q_f32(a_ptr + 5 * stride_a);
  b0 = vld1q_f32(b_ptr);
  b1 = vld1q_f32(b_ptr + 1 * stride_b);
  b2 = vld1q_f32(b_ptr + 2 * stride_b);
  b3 = vld1q_f32(b_ptr + 3 * stride_b);
  c0 = vld1q_f32(c_ptr);
  c1 = vld1q_f32(c_ptr + 1 * stride_c);
  c2 = vld1q_f32(c_ptr + 2 * stride_c);
  c3 = vld1q_f32(c_ptr + 3 * stride_c);
  c4 = vld1q_f32(c_ptr + 4 * stride_c);
  c5 = vld1q_f32(c_ptr + 5 * stride_c);

  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);
  MACE_GEMM_PART_CAL_4(3);
  MACE_GEMM_PART_CAL_4(4);
  MACE_GEMM_PART_CAL_4(5);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
  vst1q_f32(c_ptr + 4 * stride_c, c4);
  vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
  GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void GemmX44(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr,
                    int row) {
  switch (row) {
    case 1:
      Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 2:
      Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 3:
      Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 4:
      Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 5:
      Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 6:
      Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    default:
      MACE_NOT_IMPLEMENTED;
  }
}

inline void Gemm884(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
@@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr,
  c6 = vld1q_f32(c_ptr + 6 * stride_c);
  c7 = vld1q_f32(c_ptr + 7 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-  MACE_GEMM_PART_CAL(5, 10, 11);
-  MACE_GEMM_PART_CAL(6, 12, 13);
-  MACE_GEMM_PART_CAL(7, 14, 15);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-  MACE_GEMM_PART_CAL(5, 10, 11);
-  MACE_GEMM_PART_CAL(6, 12, 13);
-  MACE_GEMM_PART_CAL(7, 14, 15);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
+  MACE_GEMM_PART_CAL_8(2, 4, 5);
+  MACE_GEMM_PART_CAL_8(3, 6, 7);
+  MACE_GEMM_PART_CAL_8(4, 8, 9);
+  MACE_GEMM_PART_CAL_8(5, 10, 11);
+  MACE_GEMM_PART_CAL_8(6, 12, 13);
+  MACE_GEMM_PART_CAL_8(7, 14, 15);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
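With MACE_GEMM_PART_CAL_8 now defined for both architectures in the first two hunks, the per-row #if/#else duplication above collapses into one call per output row. As a standalone illustration (my sketch, not code from this commit; the b3, b4 and b5 terms lie outside the hunks shown here and simply follow the evident pattern), the aarch64 form of MACE_GEMM_PART_CAL_8(0, 0, 1) accumulates output row c0 from eight rows of B, scaled by the lanes of a0 and then a1:

#include <arm_neon.h>

// Illustration only (assumes aarch64 NEON): eight-deep accumulation for
// output row 0, using the four lanes of a0 and then the four lanes of a1.
static inline float32x4_t PartCal8Row0(float32x4_t c0, float32x4_t a0,
                                       float32x4_t a1, const float32x4_t b[8]) {
  c0 = vfmaq_laneq_f32(c0, b[0], a0, 0);
  c0 = vfmaq_laneq_f32(c0, b[1], a0, 1);
  c0 = vfmaq_laneq_f32(c0, b[2], a0, 2);
  c0 = vfmaq_laneq_f32(c0, b[3], a0, 3);
  c0 = vfmaq_laneq_f32(c0, b[4], a1, 0);
  c0 = vfmaq_laneq_f32(c0, b[5], a1, 1);
  c0 = vfmaq_laneq_f32(c0, b[6], a1, 2);
  c0 = vfmaq_laneq_f32(c0, b[7], a1, 3);
  return c0;
}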
@@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr,
  c0 = vld1q_f32(c_ptr);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
  vst1q_f32(c_ptr, c0);
#else
@@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr,
  c0 = vld1q_f32(c_ptr);
  c1 = vld1q_f32(c_ptr + 1 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
@@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr,
  c1 = vld1q_f32(c_ptr + 1 * stride_c);
  c2 = vld1q_f32(c_ptr + 2 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
+  MACE_GEMM_PART_CAL_8(2, 4, 5);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
@@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr,
  c2 = vld1q_f32(c_ptr + 2 * stride_c);
  c3 = vld1q_f32(c_ptr + 3 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
+  MACE_GEMM_PART_CAL_8(2, 4, 5);
+  MACE_GEMM_PART_CAL_8(3, 6, 7);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
@@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr,
  c3 = vld1q_f32(c_ptr + 3 * stride_c);
  c4 = vld1q_f32(c_ptr + 4 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
+  MACE_GEMM_PART_CAL_8(2, 4, 5);
+  MACE_GEMM_PART_CAL_8(3, 6, 7);
+  MACE_GEMM_PART_CAL_8(4, 8, 9);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
@@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr,
  c4 = vld1q_f32(c_ptr + 4 * stride_c);
  c5 = vld1q_f32(c_ptr + 5 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-  MACE_GEMM_PART_CAL(5, 10, 11);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-  MACE_GEMM_PART_CAL(5, 10, 11);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
+  MACE_GEMM_PART_CAL_8(2, 4, 5);
+  MACE_GEMM_PART_CAL_8(3, 6, 7);
+  MACE_GEMM_PART_CAL_8(4, 8, 9);
+  MACE_GEMM_PART_CAL_8(5, 10, 11);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
@@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr,
  c5 = vld1q_f32(c_ptr + 5 * stride_c);
  c6 = vld1q_f32(c_ptr + 6 * stride_c);
-#if defined(__aarch64__)
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-  MACE_GEMM_PART_CAL(5, 10, 11);
-  MACE_GEMM_PART_CAL(6, 12, 13);
-#else
-  MACE_GEMM_PART_CAL(0, 0, 1);
-  MACE_GEMM_PART_CAL(1, 2, 3);
-  MACE_GEMM_PART_CAL(2, 4, 5);
-  MACE_GEMM_PART_CAL(3, 6, 7);
-  MACE_GEMM_PART_CAL(4, 8, 9);
-  MACE_GEMM_PART_CAL(5, 10, 11);
-  MACE_GEMM_PART_CAL(6, 12, 13);
-#endif
+  MACE_GEMM_PART_CAL_8(0, 0, 1);
+  MACE_GEMM_PART_CAL_8(1, 2, 3);
+  MACE_GEMM_PART_CAL_8(2, 4, 5);
+  MACE_GEMM_PART_CAL_8(3, 6, 7);
+  MACE_GEMM_PART_CAL_8(4, 8, 9);
+  MACE_GEMM_PART_CAL_8(5, 10, 11);
+  MACE_GEMM_PART_CAL_8(6, 12, 13);
  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + 1 * stride_c, c1);
@@ -589,9 +806,19 @@ inline void GemmTile(const float *A,
                     const index_t stride_c,
                     float *C) {
#if defined(MACE_ENABLE_NEON)
-  index_t h, w, k;
-  for (h = 0; h < height - 7; h += 8) {
-    for (k = 0; k < K - 7; k += 8) {
+  index_t h = 0;
+  index_t w = 0;
+  index_t k = 0;
+#if defined(__aarch64__)
+  int reg_height_tile = 8;
+  int reg_K_tile = 8;
+#else
+  int reg_height_tile = 6;
+  int reg_K_tile = 4;
+#endif
+
+  for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
+    for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
      const float *a_ptr = A + (h * stride_a + k);
#if defined(__aarch64__) && defined(__clang__)
      int nw = width >> 2;
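To make the new loop bounds concrete, a small worked example (my own illustration, not from the commit): with reg_height_tile = 6, the armv7 setting above, and height = 20, the row loop runs h = 0, 6, 12 and covers 18 rows; the remaining 2 rows fall through to the `if (h < height)` leftover path later in this diff, which dispatches GemmX44 (here effectively Gemm244) per 4-wide column of C.

#include <cstdio>

int main() {
  // Worked example (illustration only): tile coverage for height = 20 rows
  // with the armv7 setting reg_height_tile = 6 from the hunk above.
  const int height = 20;
  const int reg_height_tile = 6;
  int h = 0;
  for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
    std::printf("full tile: rows [%d, %d)\n", h, h + reg_height_tile);  // 0, 6, 12
  }
  std::printf("leftover rows: %d\n", height - h);  // prints 2
  return 0;
}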
@@ -833,43 +1060,180 @@ inline void GemmTile(const float *A,
        w = (width >> 2) << 2;
      }
-#else  // gcc || armv7a
+#elif defined(__aarch64__)  // gcc
      for (w = 0; w + 3 < width; w += 4) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
        Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      }
-#endif  // clang && armv8a
+#else  // armv7
      int nw = width >> 2;
      if (nw > 0) {
        float32x4_t a0, a1, a2, a3, a4, a5;
        a0 = vld1q_f32(a_ptr);
        a1 = vld1q_f32(a_ptr + 1 * stride_a);
        a2 = vld1q_f32(a_ptr + 2 * stride_a);
        a3 = vld1q_f32(a_ptr + 3 * stride_a);
        a4 = vld1q_f32(a_ptr + 4 * stride_a);
        a5 = vld1q_f32(a_ptr + 5 * stride_a);

        const float *b_ptr0 = B + k * stride_b;
        const float *b_ptr1 = B + (k + 1) * stride_b;
        const float *b_ptr2 = B + (k + 2) * stride_b;
        const float *b_ptr3 = B + (k + 3) * stride_b;

        float *c_ptr0 = C + h * stride_c;
        float *c_ptr1 = C + (h + 1) * stride_c;
        float *c_ptr2 = C + (h + 2) * stride_c;
        float *c_ptr3 = C + (h + 3) * stride_c;
        float *c_ptr4 = C + (h + 4) * stride_c;
        float *c_ptr5 = C + (h + 5) * stride_c;
        asm volatile(
          "pld [%7, #128]              \n"
          "vld1.f32 {d12-d13}, [%7]!   \n"
          "pld [%1, #128]              \n"
          "vld1.f32 {d16-d17}, [%1]    \n"
          "pld [%2, #128]              \n"
          "vld1.f32 {d18-d19}, [%2]    \n"

          "0:                          \n"
          "pld [%3, #128]              \n"
          "vld1.f32 {d20-d21}, [%3]    \n"
          "pld [%4, #128]              \n"
          "vld1.f32 {d22-d23}, [%4]    \n"
          "pld [%5, #128]              \n"
          "vld1.f32 {d24-d25}, [%5]    \n"
          "pld [%6, #128]              \n"
          "vld1.f32 {d26-d27}, [%6]    \n"

          "pld [%8, #128]              \n"
          "vld1.f32 {d14-d15}, [%8]!   \n"

          "vmla.f32 q8, q6, %e22[0]    \n"
          "vmla.f32 q9, q6, %e23[0]    \n"
          "vmla.f32 q10, q6, %e24[0]   \n"
          "vmla.f32 q11, q6, %e25[0]   \n"
          "vmla.f32 q12, q6, %e26[0]   \n"
          "vmla.f32 q13, q6, %e27[0]   \n"

          "pld [%9, #128]              \n"
          "vld1.f32 {d12-d13}, [%9]!   \n"

          "vmla.f32 q8, q7, %e22[1]    \n"
          "vmla.f32 q9, q7, %e23[1]    \n"
          "vmla.f32 q10, q7, %e24[1]   \n"
          "vmla.f32 q11, q7, %e25[1]   \n"
          "vmla.f32 q12, q7, %e26[1]   \n"
          "vmla.f32 q13, q7, %e27[1]   \n"

          "pld [%10, #128]             \n"
          "vld1.f32 {d14-d15}, [%10]!  \n"

          "vmla.f32 q8, q6, %f22[0]    \n"
          "vmla.f32 q9, q6, %f23[0]    \n"
          "vmla.f32 q10, q6, %f24[0]   \n"
          "vmla.f32 q11, q6, %f25[0]   \n"
          "vmla.f32 q12, q6, %f26[0]   \n"
          "vmla.f32 q13, q6, %f27[0]   \n"

          "vmla.f32 q8, q7, %f22[1]    \n"
          "vmla.f32 q9, q7, %f23[1]    \n"
          "vmla.f32 q10, q7, %f24[1]   \n"
          "vmla.f32 q11, q7, %f25[1]   \n"
          "vmla.f32 q12, q7, %f26[1]   \n"
          "vmla.f32 q13, q7, %f27[1]   \n"

          "vst1.f32 {d16-d17}, [%1]!   \n"
          "vst1.f32 {d18-d19}, [%2]!   \n"

          "pld [%7, #128]              \n"
          "vld1.f32 {d12-d13}, [%7]!   \n"

          "vst1.f32 {d20-d21}, [%3]!   \n"
          "vst1.f32 {d22-d23}, [%4]!   \n"

          "pld [%1, #128]              \n"
          "vld1.f32 {d16-d17}, [%1]    \n"

          "vst1.f32 {d24-d25}, [%5]!   \n"
          "vst1.f32 {d26-d27}, [%6]!   \n"

          "pld [%2, #128]              \n"
          "vld1.f32 {d18-d19}, [%2]    \n"

          "subs %0, #1                 \n"
          "bne 0b                      \n"
          : "=r"(nw),      // 0
            "=r"(c_ptr0),  // 1
            "=r"(c_ptr1),  // 2
            "=r"(c_ptr2),  // 3
            "=r"(c_ptr3),  // 4
            "=r"(c_ptr4),  // 5
            "=r"(c_ptr5),  // 6
            "=r"(b_ptr0),  // 7
            "=r"(b_ptr1),  // 8
            "=r"(b_ptr2),  // 9
            "=r"(b_ptr3)   // 10
          : "0"(nw),       // 11
            "1"(c_ptr0),   // 12
            "2"(c_ptr1),   // 13
            "3"(c_ptr2),   // 14
            "4"(c_ptr3),   // 15
            "5"(c_ptr4),   // 16
            "6"(c_ptr5),   // 17
            "7"(b_ptr0),   // 18
            "8"(b_ptr1),   // 19
            "9"(b_ptr2),   // 20
            "10"(b_ptr3),  // 21
            "w"(a0),       // 22
            "w"(a1),       // 23
            "w"(a2),       // 24
            "w"(a3),       // 25
            "w"(a4),       // 26
            "w"(a5)        // 27
          : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
            "q14", "q15");

        w = (width >> 2) << 2;
      }
#endif
      if (w < width) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
-        GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c,
-                  c_ptr);
+        GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
+                  stride_a, stride_b, stride_c, c_ptr);
      }
    }
    if (k < K) {
      const float *a_ptr = A + (h * stride_a + k);
      const float *b_ptr = B + k * stride_b;
      float *c_ptr = C + h * stride_c;
-      GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c,
-                c_ptr);
+      GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a,
+                stride_b, stride_c, c_ptr);
    }
  }
  if (h < height) {
    index_t remain_h = height - h;
-    for (k = 0; k < K - 7; k += 8) {
+    for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
      const float *a_ptr = A + (h * stride_a + k);
      index_t w;
      for (w = 0; w + 3 < width; w += 4) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
+#if defined(__aarch64__)
        GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
+#else
+        GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
+#endif
      }
      if (w < width) {
        const float *b_ptr = B + (k * stride_b + w);
        float *c_ptr = C + (h * stride_c + w);
-        GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b,
-                  stride_c, c_ptr);
+        GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
+                  stride_b, stride_c, c_ptr);
      }
    }
    if (k < K) {
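Stepping back from the gemm.cc hunks: every kernel above is a register-tiled form of the same accumulation. As a plain scalar statement of what the tiled code computes (my own reference sketch, not MACE code, with the same stride convention as the intrinsics):

// Scalar semantics (illustration only): C[h][w] += sum over k of A[h][k] * B[k][w],
// where each matrix is addressed with its own row stride.
inline void GemmTileScalar(const float *A, const float *B,
                           const int height, const int K, const int width,
                           const int stride_a, const int stride_b,
                           const int stride_c, float *C) {
  for (int h = 0; h < height; ++h) {
    for (int k = 0; k < K; ++k) {
      for (int w = 0; w < width; ++w) {
        C[h * stride_c + w] += A[h * stride_a + k] * B[k * stride_b + w];
      }
    }
  }
}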
mace/ops/shape_test.cc
@@ -38,7 +38,9 @@ void TestShapeOp(const std::vector<index_t> &input_shape) {
  std::vector<int32_t> expected_input_shape(input_shape.begin(),
                                            input_shape.end());
  if (!expected_input_shape.empty()) {
-    net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {input_shape.size()},
+    net.AddInputFromArray<CPU, int32_t>("ExpectedOutput",
+                                        {static_cast<int32_t>(input_shape.size())},
                                        expected_input_shape);
  } else {
    net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {}, {0});
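The casts added in this file (and in strided_slice_test.cc below) look like list-initialization narrowing fixes: `input_shape.size()` returns `std::size_t`, and using it inside a braced `{...}` that initializes 32-bit integer dimensions is a narrowing conversion that modern compilers reject or warn about. A minimal standalone illustration of the issue (my sketch, not MACE code):

#include <cstdint>
#include <vector>

// Illustration only: why {input_shape.size()} needs an explicit cast when the
// element type is int32_t -- size_t -> int32_t narrowing inside braces is ill-formed.
std::vector<int32_t> MakeShapeDims(const std::vector<int64_t> &input_shape) {
  // std::vector<int32_t> dims{input_shape.size()};                      // error: narrowing
  std::vector<int32_t> dims{static_cast<int32_t>(input_shape.size())};   // OK
  return dims;
}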
mace/ops/strided_slice_test.cc
@@ -37,11 +37,18 @@ void TestSlice(const std::vector<index_t> &input_shape,
               const std::vector<float> &output) {
  OpsTestNet net;
  net.AddInputFromArray<CPU, float>("Input", input_shape, input);
-  net.AddInputFromArray<CPU, int32_t>("BeginIndices", {input_shape.size()},
+  net.AddInputFromArray<CPU, int32_t>("BeginIndices",
+                                      {static_cast<int32_t>(input_shape.size())},
                                      begin_indices);
-  net.AddInputFromArray<CPU, int32_t>("EndIndices", {input_shape.size()},
+  net.AddInputFromArray<CPU, int32_t>("EndIndices",
+                                      {static_cast<int32_t>(input_shape.size())},
                                      end_indices);
-  net.AddInputFromArray<CPU, int32_t>("Strides", {input_shape.size()}, strides);
+  net.AddInputFromArray<CPU, int32_t>("Strides",
+                                      {static_cast<int32_t>(input_shape.size())},
+                                      strides);
  OpDefBuilder("StridedSlice", "StridedSliceOpTest")
      .Input("Input")