慢慢CG / Mace (forked from Xiaomi / Mace)
Commit cd506756, authored Jun 06, 2018 by 李寅

Gemm transpose

Parent: c0c0dfe5
Showing 14 changed files with 685 additions and 439 deletions (+685 -439)
mace/kernels/gemm.cc                               +386  -289
mace/kernels/gemm.h                                  +6    -2
mace/kernels/gemm_test.cc                           +40   -43
mace/kernels/matmul.h                               +40   -14
mace/kernels/opencl/buffer_to_image.cc               +5    -1
mace/kernels/opencl/helper.cc                       +16    -8
mace/kernels/opencl/image_to_buffer.cc               +5    -1
mace/kernels/opencl/matmul.cc                       +18    -9
mace/kernels/opencl/winograd_transform.cc            +2    -2
mace/ops/matmul.h                                   +25   -12
mace/ops/matmul_benchmark.cc                        +58    -2
mace/ops/matmul_test.cc                             +81   -53
mace/python/tools/converter_tool/transformer.py      +2    -2
mace/python/tools/memory_optimizer.py                +1    -1
mace/kernels/gemm.cc

@@ -12,18 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <math.h>
 #include <algorithm>
 #include <cstring>
 
+#include "mace/core/tensor.h"
+#include "mace/kernels/gemm.h"
+
 #if defined(MACE_ENABLE_NEON)
 #include <arm_neon.h>
 #endif
 
-#include "mace/core/macros.h"
-#include "mace/kernels/gemm.h"
-#include "mace/utils/logging.h"
-
 #if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
 #define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
 #endif
@@ -37,13 +35,14 @@ inline void GemmBlock(const float *A,
                       const index_t height,
                       const index_t K,
                       const index_t width,
-                      const index_t stride_k,
-                      const index_t stride_w,
+                      const index_t stride_a,
+                      const index_t stride_b,
+                      const index_t stride_c,
                       float *C) {
   for (int i = 0; i < height; ++i) {
     for (int j = 0; j < width; ++j) {
       for (int k = 0; k < K; ++k) {
-        C[i * stride_w + j] += A[i * stride_k + k] * B[k * stride_w + j];
+        C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j];
       }
     }
   }
@@ -75,8 +74,9 @@ inline void GemmBlock(const float *A,
 inline void Gemm884(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
@@ -86,38 +86,38 @@ inline void Gemm884(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-  a4 = vld1q_f32(a_ptr + 2 * stride_k);
-  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
-  a6 = vld1q_f32(a_ptr + 3 * stride_k);
-  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
-  a8 = vld1q_f32(a_ptr + 4 * stride_k);
-  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
-  a10 = vld1q_f32(a_ptr + 5 * stride_k);
-  a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
-  a12 = vld1q_f32(a_ptr + 6 * stride_k);
-  a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
-  a14 = vld1q_f32(a_ptr + 7 * stride_k);
-  a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_a);
+  a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_a);
+  a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_a);
+  a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
+  a10 = vld1q_f32(a_ptr + 5 * stride_a);
+  a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
+  a12 = vld1q_f32(a_ptr + 6 * stride_a);
+  a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
+  a14 = vld1q_f32(a_ptr + 7 * stride_a);
+  a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
-  c2 = vld1q_f32(c_ptr + 2 * stride_w);
-  c3 = vld1q_f32(c_ptr + 3 * stride_w);
-  c4 = vld1q_f32(c_ptr + 4 * stride_w);
-  c5 = vld1q_f32(c_ptr + 5 * stride_w);
-  c6 = vld1q_f32(c_ptr + 6 * stride_w);
-  c7 = vld1q_f32(c_ptr + 7 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
+  c2 = vld1q_f32(c_ptr + 2 * stride_c);
+  c3 = vld1q_f32(c_ptr + 3 * stride_c);
+  c4 = vld1q_f32(c_ptr + 4 * stride_c);
+  c5 = vld1q_f32(c_ptr + 5 * stride_c);
+  c6 = vld1q_f32(c_ptr + 6 * stride_c);
+  c7 = vld1q_f32(c_ptr + 7 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -140,25 +140,28 @@ inline void Gemm884(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
-  vst1q_f32(c_ptr + 2 * stride_w, c2);
-  vst1q_f32(c_ptr + 3 * stride_w, c3);
-  vst1q_f32(c_ptr + 4 * stride_w, c4);
-  vst1q_f32(c_ptr + 5 * stride_w, c5);
-  vst1q_f32(c_ptr + 6 * stride_w, c6);
-  vst1q_f32(c_ptr + 7 * stride_w, c7);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
+  vst1q_f32(c_ptr + 2 * stride_c, c2);
+  vst1q_f32(c_ptr + 3 * stride_c, c3);
+  vst1q_f32(c_ptr + 4 * stride_c, c4);
+  vst1q_f32(c_ptr + 5 * stride_c, c5);
+  vst1q_f32(c_ptr + 6 * stride_c, c6);
+  vst1q_f32(c_ptr + 7 * stride_c, c7);
 #else
-  GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm184(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
-  MACE_UNUSED(stride_k);
 #if defined(MACE_ENABLE_NEON)
+  MACE_UNUSED(stride_a);
+  MACE_UNUSED(stride_c);
   float32x4_t a0, a1;
   float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
   float32x4_t c0;
@@ -167,13 +170,13 @@ inline void Gemm184(const float *a_ptr,
   a1 = vld1q_f32(a_ptr + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
@@ -185,14 +188,15 @@ inline void Gemm184(const float *a_ptr,
   vst1q_f32(c_ptr, c0);
 #else
-  GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm284(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3;
@@ -201,20 +205,20 @@ inline void Gemm284(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -225,16 +229,17 @@ inline void Gemm284(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
 #else
-  GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm384(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5;
@@ -243,23 +248,23 @@ inline void Gemm384(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-  a4 = vld1q_f32(a_ptr + 2 * stride_k);
-  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_a);
+  a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
-  c2 = vld1q_f32(c_ptr + 2 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
+  c2 = vld1q_f32(c_ptr + 2 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -272,17 +277,18 @@ inline void Gemm384(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
-  vst1q_f32(c_ptr + 2 * stride_w, c2);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
+  vst1q_f32(c_ptr + 2 * stride_c, c2);
 #else
-  GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm484(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
@@ -291,26 +297,26 @@ inline void Gemm484(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-  a4 = vld1q_f32(a_ptr + 2 * stride_k);
-  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
-  a6 = vld1q_f32(a_ptr + 3 * stride_k);
-  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_a);
+  a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_a);
+  a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
-  c2 = vld1q_f32(c_ptr + 2 * stride_w);
-  c3 = vld1q_f32(c_ptr + 3 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
+  c2 = vld1q_f32(c_ptr + 2 * stride_c);
+  c3 = vld1q_f32(c_ptr + 3 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -325,18 +331,19 @@ inline void Gemm484(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
-  vst1q_f32(c_ptr + 2 * stride_w, c2);
-  vst1q_f32(c_ptr + 3 * stride_w, c3);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
+  vst1q_f32(c_ptr + 2 * stride_c, c2);
+  vst1q_f32(c_ptr + 3 * stride_c, c3);
 #else
-  GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm584(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9;
@@ -345,29 +352,29 @@ inline void Gemm584(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-  a4 = vld1q_f32(a_ptr + 2 * stride_k);
-  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
-  a6 = vld1q_f32(a_ptr + 3 * stride_k);
-  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
-  a8 = vld1q_f32(a_ptr + 4 * stride_k);
-  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_a);
+  a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_a);
+  a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_a);
+  a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
-  c2 = vld1q_f32(c_ptr + 2 * stride_w);
-  c3 = vld1q_f32(c_ptr + 3 * stride_w);
-  c4 = vld1q_f32(c_ptr + 4 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
+  c2 = vld1q_f32(c_ptr + 2 * stride_c);
+  c3 = vld1q_f32(c_ptr + 3 * stride_c);
+  c4 = vld1q_f32(c_ptr + 4 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -384,19 +391,20 @@ inline void Gemm584(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
-  vst1q_f32(c_ptr + 2 * stride_w, c2);
-  vst1q_f32(c_ptr + 3 * stride_w, c3);
-  vst1q_f32(c_ptr + 4 * stride_w, c4);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
+  vst1q_f32(c_ptr + 2 * stride_c, c2);
+  vst1q_f32(c_ptr + 3 * stride_c, c3);
+  vst1q_f32(c_ptr + 4 * stride_c, c4);
 #else
-  GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm684(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11;
@@ -405,32 +413,32 @@ inline void Gemm684(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-  a4 = vld1q_f32(a_ptr + 2 * stride_k);
-  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
-  a6 = vld1q_f32(a_ptr + 3 * stride_k);
-  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
-  a8 = vld1q_f32(a_ptr + 4 * stride_k);
-  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
-  a10 = vld1q_f32(a_ptr + 5 * stride_k);
-  a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_a);
+  a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_a);
+  a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_a);
+  a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
+  a10 = vld1q_f32(a_ptr + 5 * stride_a);
+  a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
-  c2 = vld1q_f32(c_ptr + 2 * stride_w);
-  c3 = vld1q_f32(c_ptr + 3 * stride_w);
-  c4 = vld1q_f32(c_ptr + 4 * stride_w);
-  c5 = vld1q_f32(c_ptr + 5 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
+  c2 = vld1q_f32(c_ptr + 2 * stride_c);
+  c3 = vld1q_f32(c_ptr + 3 * stride_c);
+  c4 = vld1q_f32(c_ptr + 4 * stride_c);
+  c5 = vld1q_f32(c_ptr + 5 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -449,21 +457,22 @@ inline void Gemm684(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
-  vst1q_f32(c_ptr + 2 * stride_w, c2);
-  vst1q_f32(c_ptr + 3 * stride_w, c3);
-  vst1q_f32(c_ptr + 4 * stride_w, c4);
-  vst1q_f32(c_ptr + 5 * stride_w, c5);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
+  vst1q_f32(c_ptr + 2 * stride_c, c2);
+  vst1q_f32(c_ptr + 3 * stride_c, c3);
+  vst1q_f32(c_ptr + 4 * stride_c, c4);
+  vst1q_f32(c_ptr + 5 * stride_c, c5);
 #else
-  GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void Gemm784(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13;
@@ -472,35 +481,35 @@ inline void Gemm784(const float *a_ptr,
   a0 = vld1q_f32(a_ptr);
   a1 = vld1q_f32(a_ptr + 4);
-  a2 = vld1q_f32(a_ptr + 1 * stride_k);
-  a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-  a4 = vld1q_f32(a_ptr + 2 * stride_k);
-  a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
-  a6 = vld1q_f32(a_ptr + 3 * stride_k);
-  a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
-  a8 = vld1q_f32(a_ptr + 4 * stride_k);
-  a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
-  a10 = vld1q_f32(a_ptr + 5 * stride_k);
-  a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
-  a12 = vld1q_f32(a_ptr + 6 * stride_k);
-  a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
+  a2 = vld1q_f32(a_ptr + 1 * stride_a);
+  a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+  a4 = vld1q_f32(a_ptr + 2 * stride_a);
+  a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
+  a6 = vld1q_f32(a_ptr + 3 * stride_a);
+  a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
+  a8 = vld1q_f32(a_ptr + 4 * stride_a);
+  a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
+  a10 = vld1q_f32(a_ptr + 5 * stride_a);
+  a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
+  a12 = vld1q_f32(a_ptr + 6 * stride_a);
+  a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
 
   b0 = vld1q_f32(b_ptr);
-  b1 = vld1q_f32(b_ptr + 1 * stride_w);
-  b2 = vld1q_f32(b_ptr + 2 * stride_w);
-  b3 = vld1q_f32(b_ptr + 3 * stride_w);
-  b4 = vld1q_f32(b_ptr + 4 * stride_w);
-  b5 = vld1q_f32(b_ptr + 5 * stride_w);
-  b6 = vld1q_f32(b_ptr + 6 * stride_w);
-  b7 = vld1q_f32(b_ptr + 7 * stride_w);
+  b1 = vld1q_f32(b_ptr + 1 * stride_b);
+  b2 = vld1q_f32(b_ptr + 2 * stride_b);
+  b3 = vld1q_f32(b_ptr + 3 * stride_b);
+  b4 = vld1q_f32(b_ptr + 4 * stride_b);
+  b5 = vld1q_f32(b_ptr + 5 * stride_b);
+  b6 = vld1q_f32(b_ptr + 6 * stride_b);
+  b7 = vld1q_f32(b_ptr + 7 * stride_b);
 
   c0 = vld1q_f32(c_ptr);
-  c1 = vld1q_f32(c_ptr + 1 * stride_w);
-  c2 = vld1q_f32(c_ptr + 2 * stride_w);
-  c3 = vld1q_f32(c_ptr + 3 * stride_w);
-  c4 = vld1q_f32(c_ptr + 4 * stride_w);
-  c5 = vld1q_f32(c_ptr + 5 * stride_w);
-  c6 = vld1q_f32(c_ptr + 6 * stride_w);
+  c1 = vld1q_f32(c_ptr + 1 * stride_c);
+  c2 = vld1q_f32(c_ptr + 2 * stride_c);
+  c3 = vld1q_f32(c_ptr + 3 * stride_c);
+  c4 = vld1q_f32(c_ptr + 4 * stride_c);
+  c5 = vld1q_f32(c_ptr + 5 * stride_c);
+  c6 = vld1q_f32(c_ptr + 6 * stride_c);
 
 #if defined(__aarch64__)
   MACE_GEMM_PART_CAL(0, 0, 1);
@@ -521,48 +530,49 @@ inline void Gemm784(const float *a_ptr,
 #endif
 
   vst1q_f32(c_ptr, c0);
-  vst1q_f32(c_ptr + 1 * stride_w, c1);
-  vst1q_f32(c_ptr + 2 * stride_w, c2);
-  vst1q_f32(c_ptr + 3 * stride_w, c3);
-  vst1q_f32(c_ptr + 4 * stride_w, c4);
-  vst1q_f32(c_ptr + 5 * stride_w, c5);
-  vst1q_f32(c_ptr + 6 * stride_w, c6);
+  vst1q_f32(c_ptr + 1 * stride_c, c1);
+  vst1q_f32(c_ptr + 2 * stride_c, c2);
+  vst1q_f32(c_ptr + 3 * stride_c, c3);
+  vst1q_f32(c_ptr + 4 * stride_c, c4);
+  vst1q_f32(c_ptr + 5 * stride_c, c5);
+  vst1q_f32(c_ptr + 6 * stride_c, c6);
 #else
-  GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_k, stride_w, c_ptr);
+  GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_a, stride_b, stride_c, c_ptr);
 #endif
 }
 
 inline void GemmX84(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_k,
-                    index_t stride_w,
+                    const index_t stride_a,
+                    const index_t stride_b,
+                    const index_t stride_c,
                     float *c_ptr,
                     int row) {
   switch (row) {
     case 1:
-      Gemm184(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm184(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
       break;
     case 2:
-      Gemm284(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm284(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 3:
-      Gemm384(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm384(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 4:
-      Gemm484(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm484(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 5:
-      Gemm584(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm584(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 6:
-      Gemm684(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm684(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 7:
-      Gemm784(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm784(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    case 8:
-      Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      break;
    default:
      MACE_NOT_IMPLEMENTED;
@@ -574,14 +584,15 @@ inline void GemmTile(const float *A,
                      const index_t height,
                      const index_t K,
                      const index_t width,
-                     const index_t stride_k,
-                     const index_t stride_w,
+                     const index_t stride_a,
+                     const index_t stride_b,
+                     const index_t stride_c,
                      float *C) {
 #if defined(MACE_ENABLE_NEON)
   index_t h, w, k;
   for (h = 0; h < height - 7; h += 8) {
     for (k = 0; k < K - 7; k += 8) {
-      const float *a_ptr = A + (h * stride_k + k);
+      const float *a_ptr = A + (h * stride_a + k);
 #if defined(__aarch64__) && defined(__clang__)
       int nw = width >> 2;
       if (nw > 0) {
@@ -590,38 +601,38 @@ inline void GemmTile(const float *A,
             a14, a15;
        a0 = vld1q_f32(a_ptr);
        a1 = vld1q_f32(a_ptr + 4);
-        a2 = vld1q_f32(a_ptr + 1 * stride_k);
-        a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
-        a4 = vld1q_f32(a_ptr + 2 * stride_k);
-        a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
-        a6 = vld1q_f32(a_ptr + 3 * stride_k);
-        a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
-        a8 = vld1q_f32(a_ptr + 4 * stride_k);
-        a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
-        a10 = vld1q_f32(a_ptr + 5 * stride_k);
-        a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
-        a12 = vld1q_f32(a_ptr + 6 * stride_k);
-        a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
-        a14 = vld1q_f32(a_ptr + 7 * stride_k);
-        a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
+        a2 = vld1q_f32(a_ptr + 1 * stride_a);
+        a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
+        a4 = vld1q_f32(a_ptr + 2 * stride_a);
+        a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
+        a6 = vld1q_f32(a_ptr + 3 * stride_a);
+        a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
+        a8 = vld1q_f32(a_ptr + 4 * stride_a);
+        a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
+        a10 = vld1q_f32(a_ptr + 5 * stride_a);
+        a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
+        a12 = vld1q_f32(a_ptr + 6 * stride_a);
+        a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
+        a14 = vld1q_f32(a_ptr + 7 * stride_a);
+        a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);
 
-        const float *b_ptr0 = B + k * stride_w;
-        const float *b_ptr1 = B + (k + 1) * stride_w;
-        const float *b_ptr2 = B + (k + 2) * stride_w;
-        const float *b_ptr3 = B + (k + 3) * stride_w;
-        const float *b_ptr4 = B + (k + 4) * stride_w;
-        const float *b_ptr5 = B + (k + 5) * stride_w;
-        const float *b_ptr6 = B + (k + 6) * stride_w;
-        const float *b_ptr7 = B + (k + 7) * stride_w;
+        const float *b_ptr0 = B + k * stride_b;
+        const float *b_ptr1 = B + (k + 1) * stride_b;
+        const float *b_ptr2 = B + (k + 2) * stride_b;
+        const float *b_ptr3 = B + (k + 3) * stride_b;
+        const float *b_ptr4 = B + (k + 4) * stride_b;
+        const float *b_ptr5 = B + (k + 5) * stride_b;
+        const float *b_ptr6 = B + (k + 6) * stride_b;
+        const float *b_ptr7 = B + (k + 7) * stride_b;
 
-        float *c_ptr0 = C + h * stride_w;
-        float *c_ptr1 = C + (h + 1) * stride_w;
-        float *c_ptr2 = C + (h + 2) * stride_w;
-        float *c_ptr3 = C + (h + 3) * stride_w;
-        float *c_ptr4 = C + (h + 4) * stride_w;
-        float *c_ptr5 = C + (h + 5) * stride_w;
-        float *c_ptr6 = C + (h + 6) * stride_w;
-        float *c_ptr7 = C + (h + 7) * stride_w;
+        float *c_ptr0 = C + h * stride_c;
+        float *c_ptr1 = C + (h + 1) * stride_c;
+        float *c_ptr2 = C + (h + 2) * stride_c;
+        float *c_ptr3 = C + (h + 3) * stride_c;
+        float *c_ptr4 = C + (h + 4) * stride_c;
+        float *c_ptr5 = C + (h + 5) * stride_c;
+        float *c_ptr6 = C + (h + 6) * stride_c;
+        float *c_ptr7 = C + (h + 7) * stride_c;
 
        asm volatile(
            "prfm pldl1keep, [%9, #128] \n"
@@ -824,53 +835,68 @@ inline void GemmTile(const float *A,
      }
 #else  // gcc || armv7a
      for (w = 0; w + 3 < width; w += 4) {
-        const float *b_ptr = B + (k * stride_w + w);
-        float *c_ptr = C + (h * stride_w + w);
-        Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+        const float *b_ptr = B + (k * stride_b + w);
+        float *c_ptr = C + (h * stride_c + w);
+        Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
      }
 #endif  // clang && armv8a
      if (w < width) {
-        const float *b_ptr = B + (k * stride_w + w);
-        float *c_ptr = C + (h * stride_w + w);
-        GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
+        const float *b_ptr = B + (k * stride_b + w);
+        float *c_ptr = C + (h * stride_c + w);
+        GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c,
+                  c_ptr);
      }
    }
    if (k < K) {
-      const float *a_ptr = A + (h * stride_k + k);
-      const float *b_ptr = B + k * stride_w;
-      float *c_ptr = C + h * stride_w;
-      GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr);
+      const float *a_ptr = A + (h * stride_a + k);
+      const float *b_ptr = B + k * stride_b;
+      float *c_ptr = C + h * stride_c;
+      GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c,
+                c_ptr);
    }
  }
  if (h < height) {
    index_t remain_h = height - h;
    for (k = 0; k < K - 7; k += 8) {
-      const float *a_ptr = A + (h * stride_k + k);
+      const float *a_ptr = A + (h * stride_a + k);
      index_t w;
      for (w = 0; w + 3 < width; w += 4) {
-        const float *b_ptr = B + (k * stride_w + w);
-        float *c_ptr = C + (h * stride_w + w);
-        GemmX84(a_ptr, b_ptr, stride_k, stride_w, c_ptr, remain_h);
+        const float *b_ptr = B + (k * stride_b + w);
+        float *c_ptr = C + (h * stride_c + w);
+        GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
      }
      if (w < width) {
-        const float *b_ptr = B + (k * stride_w + w);
-        float *c_ptr = C + (h * stride_w + w);
-        GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_k, stride_w,
-                  c_ptr);
+        const float *b_ptr = B + (k * stride_b + w);
+        float *c_ptr = C + (h * stride_c + w);
+        GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b,
+                  stride_c, c_ptr);
      }
    }
    if (k < K) {
-      const float *a_ptr = A + (h * stride_k + k);
-      const float *b_ptr = B + k * stride_w;
-      float *c_ptr = C + h * stride_w;
-      GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_k, stride_w,
-                c_ptr);
+      const float *a_ptr = A + (h * stride_a + k);
+      const float *b_ptr = B + k * stride_b;
+      float *c_ptr = C + h * stride_c;
+      GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b,
+                stride_c, c_ptr);
    }
  }
 #else
-  GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
+  GemmBlock(A, B, height, K, width, stride_a, stride_b, stride_c, C);
 #endif  // MACE_ENABLE_NEON
 }
 
+void Transpose(const float *src, index_t height, index_t width,
+               index_t stride_w, float *dst) {
+  for (index_t h = 0; h < height; ++h) {
+    for (index_t w = 0; w < width; ++w) {
+      dst[w * height + h] = src[h * stride_w + w];
+    }
+  }
+}
+
 }  // namespace
 
 // A: height x K, B: K x width, C: height x width
@@ -880,7 +906,9 @@ void Gemm(const float *A,
           const index_t height,
           const index_t K,
           const index_t width,
-          float *C) {
+          float *C,
+          const bool transpose_a,
+          const bool transpose_b) {
   if (width == 1) {
     for (index_t b = 0; b < batch; ++b) {
       Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
@@ -898,41 +926,77 @@ void Gemm(const float *A,
   const index_t block_tile_height = RoundUpDiv(height, block_size);
   const index_t block_tile_width = RoundUpDiv(width, block_size);
   const index_t block_tile_k = RoundUpDiv(K, block_size);
+  const index_t block_tile[3] = {block_tile_height, block_tile_width,
+                                 block_tile_k};
   const index_t remain_height = height % block_size;
   const index_t remain_width = width % block_size;
   const index_t remain_k = K % block_size;
+  const index_t remain[3] = {remain_height, remain_width, remain_k};
 
 #pragma omp parallel for collapse(3)
  for (index_t n = 0; n < batch; ++n) {
-    for (index_t bh = 0; bh < block_tile_height; ++bh) {
-      for (index_t bw = 0; bw < block_tile_width; ++bw) {
+    for (index_t bh = 0; bh < block_tile[0]; ++bh) {
+      for (index_t bw = 0; bw < block_tile[1]; ++bw) {
        const float *a_base = A + n * height * K;
        const float *b_base = B + n * K * width;
        float *c_base = C + n * height * width;
 
        const index_t ih_begin = bh * block_size;
        const index_t ih_end =
-            bh * block_size + (bh == block_tile_height - 1 && remain_height > 0
-                                   ? remain_height
-                                   : block_size);
+            bh * block_size +
+            (bh == block_tile[0] - 1 && remain[0] > 0 ? remain[0]
+                                                      : block_size);
        const index_t iw_begin = bw * block_size;
        const index_t iw_end =
-            bw * block_size + (bw == block_tile_width - 1 && remain_width > 0
-                                   ? remain_width
-                                   : block_size);
+            bw * block_size +
+            (bw == block_tile[1] - 1 && remain[1] > 0 ? remain[1]
+                                                      : block_size);
 
-        for (index_t bk = 0; bk < block_tile_k; ++bk) {
+        for (index_t bk = 0; bk < block_tile[2]; ++bk) {
          const index_t ik_begin = bk * block_size;
          const index_t ik_end =
              bk * block_size +
-              (bk == block_tile_k - 1 && remain_k > 0 ? remain_k : block_size);
+              (bk == block_tile[2] - 1 && remain[2] > 0
+                   ? remain[2] : block_size);
+
+          Tensor trans_a;
+          Tensor trans_b;
+          const float *real_a = nullptr;
+          const float *real_b = nullptr;
+          float *real_c = c_base + (ih_begin * width + iw_begin);
+          index_t stride_a;
+          index_t stride_b;
+          index_t stride_c = width;
+
+          if (transpose_a) {
+            trans_a.Resize({block_size, block_size});
+            float *trans_a_data = trans_a.mutable_data<float>();
+            // A[K, H] -> A[H, K]
+            Transpose(a_base + (ik_begin * height + ih_begin),
+                      ik_end - ik_begin, ih_end - ih_begin, height,
+                      trans_a_data);
+            real_a = trans_a_data;
+            stride_a = ik_end - ik_begin;
+          } else {
+            real_a = a_base + (ih_begin * K + ik_begin);
+            stride_a = K;
+          }
+
+          if (transpose_b) {
+            trans_b.Resize({block_size, block_size});
+            float *trans_b_data = trans_b.mutable_data<float>();
+            // B[W, K] -> B[K, W]
+            Transpose(b_base + (iw_begin * K + ik_begin),
+                      iw_end - iw_begin, ik_end - ik_begin, K,
+                      trans_b_data);
+            real_b = trans_b_data;
+            stride_b = iw_end - iw_begin;
+          } else {
+            real_b = b_base + (ik_begin * width + iw_begin);
+            stride_b = width;
+          }
 
          // inside block:
          // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
-          GemmTile(a_base + (ih_begin * K + ik_begin),
-                   b_base + (ik_begin * width + iw_begin),
-                   ih_end - ih_begin,
-                   ik_end - ik_begin, iw_end - iw_begin, K, width,
-                   c_base + (ih_begin * width + iw_begin));
+          GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin,
+                   iw_end - iw_begin, stride_a, stride_b, stride_c, real_c);
        }  // bk
      }  // bw
    }  // bh
@@ -946,14 +1010,47 @@ void GemmRef(const float *A,
              const index_t height,
              const index_t K,
              const index_t width,
-             float *C) {
+             float *C,
+             const bool transpose_a,
+             const bool transpose_b) {
   memset(C, 0, sizeof(float) * batch * height * width);
+  Tensor trans_a;
+  Tensor trans_b;
+  float *trans_a_data = nullptr;
+  float *trans_b_data = nullptr;
+  if (transpose_a) {
+    trans_a.Resize({height, K});
+    trans_a_data = trans_a.mutable_data<float>();
+  }
+  if (transpose_b) {
+    trans_b.Resize({K, width});
+    trans_b_data = trans_b.mutable_data<float>();
+  }
  for (index_t b = 0; b < batch; ++b) {
+    const float *real_a = nullptr;
+    const float *real_b = nullptr;
+    float *real_c = C + b * height * width;
+    if (transpose_a) {
+      // A[K, H] -> A[H, K]
+      Transpose(A + b * height * K, K, height, height, trans_a_data);
+      real_a = trans_a_data;
+    } else {
+      real_a = A + b * height * K;
+    }
+    if (transpose_b) {
+      // B[W, K] -> B[K, W]
+      Transpose(B + b * width * K, width, K, K, trans_b_data);
+      real_b = trans_b_data;
+    } else {
+      real_b = B + b * width * K;
+    }
    for (index_t i = 0; i < height; ++i) {
      for (index_t j = 0; j < width; ++j) {
        for (index_t k = 0; k < K; ++k) {
-          C[(b * height + i) * width + j] +=
-              A[(b * height + i) * K + k] * B[(b * K + k) * width + j];
+          real_c[i * width + j] += real_a[i * K + k] * real_b[k * width + j];
        }
      }
    }
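The change above is easier to read away from the NEON intrinsics: every kernel now receives separate row strides for A, B and C, and a transposed operand is repacked through the new Transpose helper before the multiply, so the inner loops themselves never change. Below is a minimal, self-contained C++ sketch of that idea; it mirrors the GemmBlock/Transpose semantics in this diff but is not the MACE code, and the names, sizes and values are made up:

    #include <cstdio>

    // C[i, j] += A[i, k] * B[k, j], with independent row strides so the three
    // operands may live in differently shaped buffers.
    static void GemmBlockSketch(const float *A, const float *B,
                                int height, int K, int width,
                                int stride_a, int stride_b, int stride_c,
                                float *C) {
      for (int i = 0; i < height; ++i)
        for (int j = 0; j < width; ++j)
          for (int k = 0; k < K; ++k)
            C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j];
    }

    // Transpose a height x width block whose rows are stride_w apart,
    // matching the semantics of the Transpose helper added in this file.
    static void TransposeSketch(const float *src, int height, int width,
                                int stride_w, float *dst) {
      for (int h = 0; h < height; ++h)
        for (int w = 0; w < width; ++w)
          dst[w * height + h] = src[h * stride_w + w];
    }

    int main() {
      const float At[6] = {1, 4, 2, 5, 3, 6};  // A stored transposed: 3 x 2
      const float B[6] = {1, 0, 0, 1, 1, 1};   // 3 x 2
      float A[6] = {0};                        // repacked A: 2 x 3
      float C[4] = {0};                        // 2 x 2 result, zeroed first
      TransposeSketch(At, 3, 2, 2, A);
      GemmBlockSketch(A, B, 2, 3, 2, 3, 2, 2, C);
      std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // 4 5 / 10 11
      return 0;
    }

As in the diff, C is an accumulator, so it has to be zeroed before the first call.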
mace/kernels/gemm.h

@@ -30,7 +30,9 @@ void Gemm(const float *A,
           const index_t height,
           const index_t K,
           const index_t width,
-          float *C);
+          float *C,
+          const bool transpose_a = false,
+          const bool transpose_b = false);
 
 void GemmRef(const float *A,
              const float *B,
@@ -38,7 +40,9 @@ void GemmRef(const float *A,
              const index_t height,
              const index_t K,
              const index_t width,
-             float *C);
+             float *C,
+             const bool transpose_a = false,
+             const bool transpose_b = false);
 
 void Gemv(const float *m_ptr,
           const float *v_ptr,
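Because both new parameters default to false, existing call sites of kernels::Gemm keep compiling unchanged; only callers that store an operand transposed need to pass the flags. A hedged usage sketch follows — the buffer sizes and the surrounding function are invented, and only the signature comes from this header:

    #include <memory>

    #include "mace/kernels/gemm.h"

    void ExampleTransposedA() {
      const int batch = 1, height = 4, K = 8, width = 6;
      // A is kept in memory as A^T, i.e. K x height per batch.
      std::unique_ptr<float[]> A(new float[batch * K * height]());
      std::unique_ptr<float[]> B(new float[batch * K * width]());
      // Gemm accumulates into C, so value-initialize it to zero.
      std::unique_ptr<float[]> C(new float[batch * height * width]());
      // ... fill A and B ...
      // C = A^T * B: set transpose_a, leave transpose_b at its default.
      mace::kernels::Gemm(A.get(), B.get(), batch, height, K, width, C.get(),
                          /*transpose_a=*/true);
    }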
mace/kernels/gemm_test.cc

@@ -13,17 +13,22 @@
 // limitations under the License.
 
 #include <gtest/gtest.h>
-#include <random>
 #include <memory>
+#include <random>
 
-#include "mace/kernels/gemm.h"
 #include "mace/core/types.h"
+#include "mace/kernels/gemm.h"
 
 namespace mace {
 namespace {
 
-void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
+void GemmTest(index_t batch,
+              index_t N,
+              index_t K,
+              index_t M,
+              bool transpose_a,
+              bool transpose_b) {
   std::unique_ptr<float[]> A(new float[batch * N * K]);
   std::unique_ptr<float[]> B(new float[batch * K * M]);
   std::unique_ptr<float[]> C(new float[batch * N * M]);
@@ -34,15 +39,13 @@ void GemmTest(index_t batch, index_t N, index_t K, index_t M) {
   std::normal_distribution<float> nd(0, 1);
   std::generate(A.get(), A.get() + batch * N * K,
                 [&gen, &nd] { return nd(gen); });
   std::generate(B.get(), B.get() + batch * K * M,
                 [&gen, &nd] { return nd(gen); });
-  kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get());
-  kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get());
+  kernels::Gemm(A.get(), B.get(), batch, N, K, M, C.get(), transpose_a,
+                transpose_b);
+  kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
+                   transpose_b);
 
   for (int i = 0; i < batch * N * M; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
@@ -59,14 +62,8 @@ void GemvTest(index_t batch, index_t N, index_t M) {
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(A.get(), A.get() + N * M,
-                [&gen, &nd] {
-                  return nd(gen);
-                });
-  std::generate(B.get(), B.get() + batch * M,
-                [&gen, &nd] {
-                  return nd(gen);
-                });
+  std::generate(A.get(), A.get() + N * M,
+                [&gen, &nd] { return nd(gen); });
+  std::generate(B.get(), B.get() + batch * M,
+                [&gen, &nd] { return nd(gen); });
 
   kernels::Gemv(A.get(), B.get(), batch, M, N, C.get());
   kernels::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
@@ -78,36 +75,36 @@ void GemvTest(index_t batch, index_t N, index_t M) {
 }  // namespace
 
 TEST(GEMMTest, AlignedWithoutBatch) {
-  GemmTest(1, 1, 64, 128);
-  GemmTest(1, 2, 64, 128);
-  GemmTest(1, 3, 64, 128);
-  GemmTest(1, 4, 64, 128);
-  GemmTest(1, 5, 64, 128);
-  GemmTest(1, 6, 64, 128);
-  GemmTest(1, 7, 64, 128);
-  GemmTest(1, 17, 64, 128);
+  GemmTest(1, 1, 64, 128, false, false);
+  GemmTest(1, 2, 64, 128, false, true);
+  GemmTest(1, 3, 64, 128, true, false);
+  GemmTest(1, 4, 64, 128, true, true);
+  GemmTest(1, 5, 64, 128, false, false);
+  GemmTest(1, 6, 64, 128, false, true);
+  GemmTest(1, 7, 64, 128, true, false);
+  GemmTest(1, 17, 64, 128, true, true);
 }
 
 TEST(GEMMTest, UnalignedWithoutBatch) {
-  GemmTest(1, 1, 63, 127);
-  GemmTest(1, 2, 63, 127);
-  GemmTest(1, 3, 63, 127);
-  GemmTest(1, 4, 63, 127);
-  GemmTest(1, 5, 63, 127);
-  GemmTest(1, 6, 63, 127);
-  GemmTest(1, 7, 63, 127);
-  GemmTest(1, 17, 63, 127);
+  GemmTest(1, 1, 63, 127, false, false);
+  GemmTest(1, 2, 63, 127, false, true);
+  GemmTest(1, 3, 63, 127, true, false);
+  GemmTest(1, 4, 63, 127, true, true);
+  GemmTest(1, 5, 63, 127, false, false);
+  GemmTest(1, 6, 63, 127, false, true);
+  GemmTest(1, 7, 63, 127, true, false);
+  GemmTest(1, 17, 63, 127, true, true);
 }
 
 TEST(GEMMTest, UnalignedWithBatch) {
-  GemmTest(3, 1, 63, 127);
-  GemmTest(3, 2, 63, 127);
-  GemmTest(3, 3, 63, 127);
-  GemmTest(3, 4, 63, 127);
-  GemmTest(3, 5, 63, 127);
-  GemmTest(3, 6, 63, 127);
-  GemmTest(3, 7, 63, 127);
-  GemmTest(3, 17, 63, 127);
+  GemmTest(3, 1, 63, 127, false, false);
+  GemmTest(3, 2, 63, 127, false, true);
+  GemmTest(3, 3, 63, 127, true, false);
+  GemmTest(3, 4, 63, 127, true, true);
+  GemmTest(3, 5, 63, 127, false, false);
+  GemmTest(3, 6, 63, 127, false, true);
+  GemmTest(3, 7, 63, 127, true, false);
+  GemmTest(3, 17, 63, 127, true, true);
 }
 
 TEST(GEMMTest, gemv) {
mace/kernels/matmul.h

@@ -20,6 +20,8 @@
 #endif
 
 #include <algorithm>
+#include <utility>
+#include <functional>
 #include <memory>
 #include <string>
 #include <vector>
@@ -36,14 +38,39 @@
 namespace mace {
 namespace kernels {
 
 template <DeviceType D, typename T>
 struct MatMulFunctor {
   MaceStatus operator()(const Tensor *A,
                         const Tensor *B,
                         Tensor *C,
+                        bool transpose_a,
+                        bool transpose_b,
                         StatsFuture *future) {
     MACE_UNUSED(future);
-    std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
+    index_t batch;
+    index_t height;
+    index_t K;
+    index_t width;
+
+    index_t rank = A->dim_size();
+    height = A->dim(rank - 2);
+    K = A->dim(rank - 1);
+    if (transpose_a) {
+      std::swap(height, K);
+    }
+    if (transpose_b) {
+      width = B->dim(rank - 2);
+    } else {
+      width = B->dim(rank - 1);
+    }
+    batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
+                            std::multiplies<index_t>());
+
+    std::vector<index_t> c_shape = A->shape();
+    c_shape[rank - 2] = height;
+    c_shape[rank - 1] = width;
+
     MACE_RETURN_IF_ERROR(C->Resize(c_shape));
 
     Tensor::MappingGuard guarda(A);
@@ -53,28 +80,27 @@ struct MatMulFunctor {
     const T *b_ptr_base = B->data<T>();
     T *c_ptr_base = C->mutable_data<T>();
 
-    const index_t batch = C->dim(0);
-    const index_t height = C->dim(1);
-    const index_t width = C->dim(2);
-    const index_t K = A->dim(2);
     // It is better to use large block size if it fits for fast cache.
     // Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
     // the block size should be sqrt(32k / sizeof(T) / 3).
     memset(c_ptr_base, 0, batch * height * width * sizeof(T));
-    Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base);
+    Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base,
+         transpose_a, transpose_b);
 
     return MACE_SUCCESS;
   }
 };
 
 #ifdef MACE_ENABLE_OPENCL
 template <typename T>
 struct MatMulFunctor<DeviceType::GPU, T> {
   MaceStatus operator()(const Tensor *A,
                         const Tensor *B,
                         Tensor *C,
+                        bool transpose_a,
+                        bool transpose_b,
                         StatsFuture *future);
 
   cl::Kernel kernel_;
   uint32_t kwg_size_;
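The CPU functor now derives the GEMM geometry from the operand ranks instead of assuming rank-4 tensors. The same shape arithmetic, pulled out into a standalone helper for clarity — the helper name is hypothetical and plain vectors stand in for Tensor shapes:

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <utility>
    #include <vector>

    using index_t = int64_t;

    // Assumes rank >= 2, as the functor does. With transpose_a the stored A is
    // [..., K, height]; with transpose_b the stored B is [..., width, K].
    void InferMatMulShape(const std::vector<index_t> &a_shape,
                          const std::vector<index_t> &b_shape,
                          bool transpose_a, bool transpose_b,
                          index_t *batch, index_t *height,
                          index_t *K, index_t *width) {
      const size_t rank = a_shape.size();
      *height = a_shape[rank - 2];
      *K = a_shape[rank - 1];
      if (transpose_a) std::swap(*height, *K);
      *width = transpose_b ? b_shape[rank - 2] : b_shape[rank - 1];
      *batch = std::accumulate(a_shape.begin(), a_shape.end() - 2,
                               static_cast<index_t>(1),
                               std::multiplies<index_t>());
    }

The output shape is then A's shape with its last two dimensions replaced by height and width, which is exactly how c_shape is built above.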
mace/kernels/opencl/buffer_to_image.cc

@@ -134,7 +134,11 @@ MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
   } else {
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
+    if (buffer->dim_size() < 4) {
+      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
+    } else {
+      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
+    }
   }
   b2f_kernel.setArg(idx++, *(image->opencl_image()));
mace/kernels/opencl/helper.cc
浏览文件 @
cd506756
 ...
 @@ -76,19 +76,27 @@ void CalWinogradFilterImageShape(
 // [W * C, N * RoundUp<4>(H)]
 void CalInOutHeightImageShape(const std::vector<index_t> &shape, /* NHWC */
                               std::vector<size_t> *image_shape) {
-  MACE_CHECK(shape.size() == 4);
+  std::vector<index_t> padded_shape = shape;
+  while (padded_shape.size() < 4) {
+    padded_shape.push_back(1);
+  }
+  MACE_CHECK(padded_shape.size() == 4);
   image_shape->resize(2);
-  (*image_shape)[0] = shape[2] * shape[3];
-  (*image_shape)[1] = shape[0] * RoundUpDiv4(shape[1]);
+  (*image_shape)[0] = padded_shape[2] * padded_shape[3];
+  (*image_shape)[1] = padded_shape[0] * RoundUpDiv4(padded_shape[1]);
 }

 // [RoundUp<4>(W) * C, N * H]
 void CalInOutWidthImageShape(const std::vector<index_t> &shape, /* NHWC */
                              std::vector<size_t> *image_shape) {
-  MACE_CHECK(shape.size() == 4);
+  std::vector<index_t> padded_shape = shape;
+  while (padded_shape.size() < 4) {
+    padded_shape.push_back(1);
+  }
+  MACE_CHECK(padded_shape.size() == 4);
   image_shape->resize(2);
-  (*image_shape)[0] = RoundUpDiv4(shape[2]) * shape[3];
-  (*image_shape)[1] = shape[0] * shape[1];
+  (*image_shape)[0] = RoundUpDiv4(padded_shape[2]) * padded_shape[3];
+  (*image_shape)[1] = padded_shape[0] * padded_shape[1];
 }

 // [Ic * H * W, (Oc + 3) / 4]
 ...
 @@ -150,10 +158,10 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
 std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
                                       const BufferType type) {
   if (type == WINOGRAD_FILTER) {
-    return {16, shape[0], shape[1], 1};
+    return {16, shape[0], shape[1]};
   } else if (type == IN_OUT_HEIGHT) {
     index_t out_width = shape[0] * ((shape[1] - 1) / 2) * ((shape[2] - 1) / 2);
-    return {16, shape[3], out_width, 1};
+    return {16, shape[3], out_width};
   } else {
     LOG(FATAL) << "Mace not supported yet.";
     return std::vector<index_t>();
 ...
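For intuition on the padding above, a hedged worked example (assuming a 3-D MatMul output shape {batch, height, width} = {1, 64, 32} is passed in): the shape is first padded to {1, 64, 32, 1}, so CalInOutHeightImageShape returns [padded_shape[2] * padded_shape[3], padded_shape[0] * RoundUpDiv4(padded_shape[1])] = [32 * 1, 1 * 16] = [32, 16], i.e. the same image a 4-D shape with a trailing 1 would have produced before this change.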
mace/kernels/opencl/image_to_buffer.cc
 ...
 @@ -122,7 +122,11 @@ MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
   } else {
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
+    if (buffer->dim_size() < 4) {
+      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
+    } else {
+      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
+    }
   }
   b2f_kernel.setArg(idx++, *(image->opencl_image()));
 ...
mace/kernels/opencl/matmul.cc
 ...
 @@ -24,17 +24,27 @@ template <typename T>
 MaceStatus MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
                                                          const Tensor *B,
                                                          Tensor *C,
+                                                         bool transpose_a,
+                                                         bool transpose_b,
                                                          StatsFuture *future) {
   MACE_UNUSED(future);
-  std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
+  MACE_CHECK(!transpose_a && !transpose_b,
+             "GPU does not support transpose matmul");
+  index_t rank = A->dim_size();
+  index_t height = A->dim(rank - 2);
+  index_t K = A->dim(rank - 1);
+  index_t width = B->dim(rank - 1);
+  index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
+                                  std::multiplies<index_t>());
+  std::vector<index_t> c_shape = A->shape();
+  c_shape[rank - 2] = height;
+  c_shape[rank - 1] = width;
   std::vector<size_t> c_image_shape;
   CalImage2DShape(c_shape, BufferType::IN_OUT_HEIGHT, &c_image_shape);
   MACE_RETURN_IF_ERROR(C->ResizeImage(c_shape, c_image_shape));

-  const index_t batch = C->dim(0);
-  const index_t height = C->dim(1);
-  const index_t width = C->dim(2);
   const index_t height_blocks = RoundUpDiv4(height);
   const index_t width_blocks = RoundUpDiv4(width);
   const uint32_t gws[2] = {
 ...
 @@ -82,13 +92,12 @@ MaceStatus MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
     kernel_.setArg(idx++, *(C->opencl_image()));
     kernel_.setArg(idx++, static_cast<int>(height));
     kernel_.setArg(idx++, static_cast<int>(width));
-    kernel_.setArg(idx++, static_cast<int>(A->dim(2)));
+    kernel_.setArg(idx++, static_cast<int>(K));
     kernel_.setArg(idx++, static_cast<int>(height_blocks));
-    kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(A->dim(2))));
+    kernel_.setArg(idx++, static_cast<int>(RoundUpDiv4(K)));

   const std::vector<uint32_t> lws = {kwg_size_ / 64, 64, 0};
-  std::string tuning_key = Concat("matmul_opencl_kernel", C->dim(0), C->dim(1),
-                                  C->dim(2), C->dim(3));
+  std::string tuning_key = Concat("matmul_opencl_kernel", batch, height, width);
   TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future);

   if (runtime->IsOutOfRangeCheckEnabled()) {
 ...
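The batch computation above folds every leading dimension of A into one count. A self-contained sketch of the same idiom (shape values and names here are illustrative, not taken from the kernel):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  using index_t = int64_t;
  // A hypothetical batched operand of rank 4: [2, 3, 31, 61] = [..., height, K].
  std::vector<index_t> a_shape = {2, 3, 31, 61};
  // Multiply all dimensions except the last two to get the flattened batch.
  const index_t batch =
      std::accumulate(a_shape.begin(), a_shape.end() - 2,
                      static_cast<index_t>(1), std::multiplies<index_t>());
  // Here batch == 6, height == 31, K == 61.
  return batch == 6 ? 0 : 1;
}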
mace/kernels/opencl/winograd_transform.cc
 ...
 @@ -74,7 +74,7 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
       static_cast<uint32_t>(RoundUpDiv4(input_tensor->dim(3)))};
   if (!IsVecEqual(input_shape_, input_tensor->shape())) {
-    output_shape = {16, input_tensor->dim(3), out_width, 1};
+    output_shape = {16, input_tensor->dim(3), out_width};
     std::vector<size_t> image_shape;
     CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, &image_shape);
     MACE_RETURN_IF_ERROR(output_tensor->ResizeImage(output_shape, image_shape));
 ...
 @@ -104,7 +104,7 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
   const std::vector<uint32_t> lws = {kwg_size_ / 8, 8, 0};
   std::string tuning_key = Concat("winograd_transform_kernel",
                                   output_tensor->dim(0), output_tensor->dim(1),
-                                  output_tensor->dim(2), output_tensor->dim(3));
+                                  output_tensor->dim(2));
   TuningOrRun2DKernel(kernel_, tuning_key, gws, lws, future);

   if (runtime->IsOutOfRangeCheckEnabled()) {
 ...
mace/ops/matmul.h
 ...
 @@ -25,24 +25,37 @@ template <DeviceType D, class T>
 class MatMulOp : public Operator<D, T> {
  public:
   MatMulOp(const OperatorDef &operator_def, Workspace *ws)
-      : Operator<D, T>(operator_def, ws) {}
+      : Operator<D, T>(operator_def, ws),
+        transpose_a_(OperatorBase::GetOptionalArg<bool>("transpose_a", false)),
+        transpose_b_(OperatorBase::GetOptionalArg<bool>("transpose_b", false)) {
+  }

   MaceStatus Run(StatsFuture *future) override {
-    const Tensor *A = this->Input(0);
-    const Tensor *B = this->Input(1);
-    Tensor *C = this->Output(0);
-    MACE_CHECK(A->dim_size() == 4 && 4 == B->dim_size())
-        << "The dimension of A and B should be 4";
-    MACE_CHECK(A->dim(0) == B->dim(0)) << "A and B must have same batch size";
-    MACE_CHECK(A->dim(2) == B->dim(1))
-        << "the number of A's column " << A->dim(2)
-        << " must be equal to B's row " << B->dim(1);
-    return functor_(A, B, C, future);
+    const Tensor *A = this->Input(INPUT_A);
+    const Tensor *B = this->Input(INPUT_B);
+    Tensor *C = this->Output(OUTPUT);
+    MACE_CHECK(A->dim_size() == B->dim_size() && A->dim_size() >= 2,
+               "rank(A) should be equal to rank(B), rank should be greater "
+               "than or equal to 2");
+    index_t rank = A->dim_size();
+    for (index_t i = 0; i < rank - 2; ++i) {
+      MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal");
+    }
+    index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
+    index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2);
+    MACE_CHECK(ak == bk, "the number of A's column ", ak,
+               " must be equal to B's row ", bk);
+    return functor_(A, B, C, transpose_a_, transpose_b_, future);
   }

  private:
+  MACE_OP_INPUT_TAGS(INPUT_A, INPUT_B);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
   kernels::MatMulFunctor<D, T> functor_;
+  bool transpose_a_;
+  bool transpose_b_;
 };

 }  // namespace ops
 ...
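The ak/bk check above is the transpose-aware version of the old A->dim(2) == B->dim(1) test. A standalone sketch of the rule (names are illustrative; the real op reads the flags from the OperatorDef as shown in the diff):

#include <cassert>
#include <cstdint>
#include <vector>

// True when op(A) * op(B) is well-formed: the reduction dimension K taken
// from A must equal the one taken from B; transposing swaps which of the
// last two axes plays the K role.
bool InnerDimsMatch(const std::vector<int64_t> &a_shape,
                    const std::vector<int64_t> &b_shape,
                    bool transpose_a, bool transpose_b) {
  const auto rank = a_shape.size();
  const int64_t ak = transpose_a ? a_shape[rank - 2] : a_shape[rank - 1];
  const int64_t bk = transpose_b ? b_shape[rank - 1] : b_shape[rank - 2];
  return ak == bk;
}

int main() {
  // A: [2, 32, 128], B: [2, 49, 128] with transpose_b -> K is 128 on both sides.
  assert(InnerDimsMatch({2, 32, 128}, {2, 49, 128}, false, true));
  // Without transpose_b the same shapes would not multiply (128 vs 49).
  assert(!InnerDimsMatch({2, 32, 128}, {2, 49, 128}, false, false));
  return 0;
}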
mace/ops/matmul_benchmark.cc
 ...
 @@ -31,8 +31,8 @@ void MatMulBenchmark(
   OpsTestNet net;

   // Add input data
-  net.AddRandomInput<D, float>("A", {batch, height, channels, 1});
-  net.AddRandomInput<D, float>("B", {batch, channels, out_width, 1});
+  net.AddRandomInput<D, float>("A", {batch, height, channels});
+  net.AddRandomInput<D, float>("B", {batch, channels, out_width});
   if (D == DeviceType::GPU) {
     BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
 ...
 @@ -65,6 +65,41 @@ void MatMulBenchmark(
   }
   net.Sync();
 }

+template <DeviceType D, typename T>
+void MatMulTransposeBenchmark(
+    int iters, int batch, int height, int channels, int out_width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, float>("A", {batch, height, channels});
+  net.AddRandomInput<D, float>("B", {batch, out_width, channels});
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("MatMul", "MatMulBM")
+        .Input("A")
+        .Input("B")
+        .AddIntArg("transpose_b", 1)
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+
 }  // namespace

 #define MACE_BM_MATMUL_MACRO(N, H, C, W, TYPE, DEVICE) \
 ...
 @@ -83,6 +118,20 @@ void MatMulBenchmark(
   MACE_BM_MATMUL_MACRO(N, H, C, W, float, GPU); \
   MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU);

+#define MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, TYPE, DEVICE)             \
+  static void MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE(  \
+      int iters) {                                                           \
+    const int64_t macc = static_cast<int64_t>(iters) * N * C * H * W;        \
+    const int64_t tot = static_cast<int64_t>(iters) * N * (C * H + H * W);   \
+    mace::testing::MaccProcessed(macc);                                      \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                      \
+    MatMulTransposeBenchmark<DEVICE, TYPE>(iters, N, H, C, W);               \
+  }                                                                          \
+  MACE_BENCHMARK(MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE)
+
+#define MACE_BM_MATMUL_TRANPOSE(N, H, C, W) \
+  MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU);
+
 MACE_BM_MATMUL(16, 32, 128, 49);
 MACE_BM_MATMUL(16, 32, 128, 961);
 MACE_BM_MATMUL(16, 32, 128, 3969);
 ...
 @@ -90,6 +139,13 @@ MACE_BM_MATMUL(16, 128, 128, 49);
 MACE_BM_MATMUL(16, 128, 128, 961);
 MACE_BM_MATMUL(16, 128, 128, 3969);

+MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 49);
+MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 961);
+MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 3969);
+MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 49);
+MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 961);
+MACE_BM_MATMUL_TRANPOSE(16, 128, 128, 3969);
+
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
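As a worked example of the new macro's bookkeeping: MACE_BM_MATMUL_TRANPOSE(16, 32, 128, 49) reports macc = 16 * 128 * 32 * 49 = 3,211,264 multiply-accumulates per iteration and tot = 16 * (128 * 32 + 32 * 49) = 90,624 elements, which the macro scales by sizeof(float) for the bytes-processed counter.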
mace/ops/matmul_test.cc
 ...
 @@ -72,46 +72,46 @@ void Simple(const std::vector<index_t> &A_shape,
 }  // namespace

 TEST_F(MatMulOpTest, SimpleCPU) {
-  Simple<DeviceType::CPU>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1},
-                          {1, 2, 3, 4, 5, 6}, {1, 2, 2, 1}, {22, 28, 49, 64});
+  Simple<DeviceType::CPU>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2},
+                          {1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64});
   Simple<DeviceType::CPU>(
-      {1, 5, 5, 1}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
-                     14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
-      {1, 5, 5, 1}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
-                     14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
-      {1, 5, 5, 1}, {215,  230,  245,  260,  275,  490,  530,  570,  610,
-                     650,  765,  830,  895,  960,  1025, 1040, 1130, 1220,
-                     1310, 1400, 1315, 1430, 1545, 1660, 1775});
+      {1, 5, 5}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
+                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+      {1, 5, 5}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
+                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+      {1, 5, 5}, {215,  230,  245,  260,  275,  490,  530,  570,  610,
+                  650,  765,  830,  895,  960,  1025, 1040, 1130, 1220,
+                  1310, 1400, 1315, 1430, 1545, 1660, 1775});
 }

 TEST_F(MatMulOpTest, SimpleCPUWithBatch) {
-  Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
-                          {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
-                          {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64});
+  Simple<DeviceType::CPU>({2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
+                          {2, 3, 2}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
+                          {2, 2, 2}, {22, 28, 49, 64, 22, 28, 49, 64});
 }

 TEST_F(MatMulOpTest, SimpleOPENCL) {
-  Simple<DeviceType::GPU>({1, 2, 3, 1}, {1, 2, 3, 4, 5, 6}, {1, 3, 2, 1},
-                          {1, 2, 3, 4, 5, 6}, {1, 2, 2, 1}, {22, 28, 49, 64});
+  Simple<DeviceType::GPU>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2},
+                          {1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64});
   Simple<DeviceType::GPU>(
-      {1, 5, 5, 1}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
-                     14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
-      {1, 5, 5, 1}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
-                     14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
-      {1, 5, 5, 1}, {215,  230,  245,  260,  275,  490,  530,  570,  610,
-                     650,  765,  830,  895,  960,  1025, 1040, 1130, 1220,
-                     1310, 1400, 1315, 1430, 1545, 1660, 1775});
+      {1, 5, 5}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
+                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+      {1, 5, 5}, {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
+                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
+      {1, 5, 5}, {215,  230,  245,  260,  275,  490,  530,  570,  610,
+                  650,  765,  830,  895,  960,  1025, 1040, 1130, 1220,
+                  1310, 1400, 1315, 1430, 1545, 1660, 1775});
 }

 TEST_F(MatMulOpTest, SimpleGPUWithBatch) {
-  Simple<DeviceType::CPU>({2, 2, 3, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
-                          {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
-                          {2, 2, 2, 1}, {22, 28, 49, 64, 22, 28, 49, 64});
+  Simple<DeviceType::CPU>({2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
+                          {2, 3, 2}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6},
+                          {2, 2, 2}, {22, 28, 49, 64, 22, 28, 49, 64});
 }

 namespace {
 template <typename T>
-void Complex(const index_t batch,
+void Complex(const std::vector<index_t> &batch,
              const index_t height,
              const index_t channels,
              const index_t out_width) {
 ...
 @@ -119,23 +119,14 @@ void Complex(const index_t batch,
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("MatMul", "MatMulTest")
-      .Input("A")
-      .Input("B")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());

   // Add input data
-  net.AddRandomInput<DeviceType::GPU, float>("A", {batch, height, channels, 1});
-  net.AddRandomInput<DeviceType::GPU, float>("B",
-                                             {batch, channels, out_width, 1});
-  // run cpu
-  net.RunOp();
-  // Check
-  Tensor expected;
-  expected.Copy(*net.GetOutput("Output"));
+  index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
+                                        std::multiplies<index_t>());
+  net.AddRandomInput<DeviceType::GPU, float>("A",
+                                             {batch_count, height, channels});
+  net.AddRandomInput<DeviceType::GPU, float>(
+      "B", {batch_count, channels, out_width});

   // Run on opencl
   BufferToImage<DeviceType::GPU, T>(&net, "A", "AImage",
 ...
 @@ -150,11 +141,40 @@ void Complex(const index_t batch,
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());

+  // Run on opencl
   net.RunOp(DeviceType::GPU);
   ImageToBuffer<DeviceType::GPU, float>(&net, "OutputImage", "OPENCLOutput",
                                         kernels::BufferType::IN_OUT_HEIGHT);

+  // run cpu
+  std::vector<index_t> shape_a = batch;
+  shape_a.push_back(height);
+  shape_a.push_back(channels);
+  std::vector<index_t> shape_b = batch;
+  shape_b.push_back(channels);
+  shape_b.push_back(out_width);
+  std::vector<index_t> expected_output_shape = batch;
+  expected_output_shape.push_back(height);
+  expected_output_shape.push_back(out_width);
+  net.GetTensor("A")->Reshape(shape_a);
+  net.GetTensor("B")->Reshape(shape_b);
+  OpDefBuilder("MatMul", "MatMulTest")
+      .Input("A")
+      .Input("B")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  // Check
+  EXPECT_EQ(expected_output_shape, net.GetOutput("Output")->shape());
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
+  expected.Reshape({batch_count, height, out_width});
+
   if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
     ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2,
                             1e-1);
 ...
 @@ -166,28 +186,36 @@ void Complex(const index_t batch,
 }  // namespace

 TEST_F(MatMulOpTest, OPENCLAlignedWithoutBatch) {
-  Complex<float>(1, 64, 128, 32);
-  Complex<float>(1, 64, 32, 128);
+  Complex<float>({1}, 64, 128, 32);
+  Complex<float>({1}, 64, 32, 128);
+  Complex<float>({2, 3}, 64, 32, 128);
 }

 TEST_F(MatMulOpTest, OPENCLUnAlignedWithoutBatch) {
-  Complex<float>(1, 31, 113, 61);
-  Complex<float>(1, 113, 31, 73);
+  Complex<float>({1}, 31, 113, 61);
+  Complex<float>({1}, 113, 31, 73);
+  Complex<float>({2, 3}, 113, 31, 73);
 }

 TEST_F(MatMulOpTest, OPENCLUnAlignedWithBatch) {
-  Complex<float>(2, 3, 3, 3);
-  Complex<float>(16, 31, 61, 67);
-  Complex<float>(31, 31, 61, 67);
+  Complex<float>({2}, 3, 3, 3);
+  Complex<float>({16}, 31, 61, 67);
+  Complex<float>({31}, 31, 61, 67);
+  Complex<float>({2, 3}, 31, 61, 67);
 }

 TEST_F(MatMulOpTest, OPENCLHalfAlignedWithoutBatch) {
-  Complex<half>(1, 64, 128, 32);
-  Complex<half>(1, 64, 32, 128);
+  Complex<half>({1}, 64, 128, 32);
+  Complex<half>({1}, 64, 32, 128);
+  Complex<half>({2, 3}, 64, 32, 128);
 }

 TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) {
-  Complex<half>(2, 31, 113, 61);
-  Complex<half>(16, 32, 64, 64);
-  Complex<half>(31, 31, 61, 67);
+  Complex<half>({2}, 31, 113, 61);
+  Complex<half>({16}, 32, 64, 64);
+  Complex<half>({31}, 31, 61, 67);
+  Complex<half>({2, 3}, 31, 61, 67);
 }

+// TODO(liyin): test transpose after implementing gpu runtime
+// now transpose test is in kernels_test
+
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
mace/python/tools/converter_tool/transformer.py
 ...
 @@ -518,7 +518,7 @@ class Transformer(base_converter.ConverterInterface):
             wt_output_width = batch * (
                 (out_height + 1) / 2) * ((out_width + 1) / 2)
             wt_output_shape.dims.extend(
-                [16, in_channels, wt_output_width, 1])
+                [16, in_channels, wt_output_width])

             if ConverterUtil.get_arg(op,
                                      MaceKeyword.mace_padding_str) \
 ...
 @@ -543,7 +543,7 @@ class Transformer(base_converter.ConverterInterface):
             matmul_op.output.extend([matmul_op.name])
             matmul_output_shape = matmul_op.output_shape.add()
             matmul_output_shape.dims.extend(
-                [16, out_channels, wt_output_width, 1])
+                [16, out_channels, wt_output_width])

             arg = matmul_op.arg.add()
             arg.name = MaceKeyword.mace_winograd_filter_transformed
 ...
mace/python/tools/memory_optimizer.py
 ...
 @@ -167,7 +167,7 @@ class GPUMemoryOptimizer(MemoryOptimizer):
     def get_op_mem_block(self, op_type, output_shape):
         mem_block = [0, 0]
         if op_type == 'WinogradTransform' or op_type == 'MatMul':
-            mem_block[0] = output_shape[2] * output_shape[3]
+            mem_block[0] = output_shape[2]
             mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4)
         else:
             mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
 ...
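As a worked example of the updated block size (assuming a 3-D MatMul output shape such as [16, 32, 49] for batch, height, width): mem_block becomes [49, 16 * int((32 + 3) / 4)] = [49, 128]. The old code indexed a fourth dimension that the now 3-D MatMul and WinogradTransform outputs no longer carry.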