Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
db75d542
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
db75d542
编写于
4月 04, 2018
作者:
吴
吴承辉
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'gemm-asm' into 'master'
Implement ASM GEMM See merge request !354
上级
1fdb5593
68401a67
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
358 additions
and
69 deletions
+358
-69
mace/kernels/arm/conv_2d.cc
mace/kernels/arm/conv_2d.cc
+2
-1
mace/kernels/arm/conv_winograd.cc
mace/kernels/arm/conv_winograd.cc
+39
-38
mace/kernels/arm/conv_winograd.h
mace/kernels/arm/conv_winograd.h
+9
-0
mace/kernels/gemm.cc
mace/kernels/gemm.cc
+300
-29
mace/kernels/gemm.h
mace/kernels/gemm.h
+7
-0
mace/kernels/gemm_test.cc
mace/kernels/gemm_test.cc
+1
-1
未找到文件。
mace/kernels/arm/conv_2d.cc
浏览文件 @
db75d542
...
@@ -162,7 +162,8 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
...
@@ -162,7 +162,8 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
if
(
USE_WINOGRAD
&&
filter_h
==
3
&&
filter_w
==
3
&&
stride_h
==
1
if
(
USE_WINOGRAD
&&
filter_h
==
3
&&
filter_w
==
3
&&
stride_h
==
1
&&
stride_w
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
)
{
&&
dilation_h
==
1
&&
dilation_w
==
1
&&
input_channels
>=
8
&&
channels
>=
8
)
{
extra_output_height
=
RoundUp
<
index_t
>
(
height
,
2
);
extra_output_height
=
RoundUp
<
index_t
>
(
height
,
2
);
extra_input_height
=
std
::
max
(
padded_input_height
,
extra_output_height
+
2
);
extra_input_height
=
std
::
max
(
padded_input_height
,
extra_output_height
+
2
);
extra_output_width
=
RoundUp
<
index_t
>
(
width
,
2
);
extra_output_width
=
RoundUp
<
index_t
>
(
width
,
2
);
...
...
mace/kernels/arm/conv_winograd.cc
浏览文件 @
db75d542
...
@@ -271,44 +271,6 @@ void TransformOutput(const float *input,
...
@@ -271,44 +271,6 @@ void TransformOutput(const float *input,
}
}
}
}
}
}
void
ConvRef3x3s1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
float
*
output
)
{
index_t
out_height
=
in_height
-
2
;
index_t
out_width
=
in_width
-
2
;
#pragma omp parallel for collapse(4)
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
for
(
index_t
m
=
0
;
m
<
out_channels
;
++
m
)
{
for
(
index_t
h
=
0
;
h
<
out_height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
++
w
)
{
index_t
out_offset
=
((
b
*
out_channels
+
m
)
*
out_height
+
h
)
*
out_width
+
w
;
output
[
out_offset
]
=
0
;
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
for
(
index_t
kh
=
0
;
kh
<
3
;
++
kh
)
{
for
(
index_t
kw
=
0
;
kw
<
3
;
++
kw
)
{
index_t
ih
=
h
+
kh
;
index_t
iw
=
w
+
kw
;
index_t
in_offset
=
((
b
*
in_channels
+
c
)
*
in_height
+
ih
)
*
in_width
+
iw
;
index_t
filter_offset
=
(((
m
*
in_channels
)
+
c
)
*
3
+
kh
)
*
3
+
kw
;
output
[
out_offset
]
+=
input
[
in_offset
]
*
filter
[
filter_offset
];
}
}
}
}
}
}
}
}
}
// namespace
}
// namespace
void
WinoGradConv3x3s1
(
const
float
*
input
,
void
WinoGradConv3x3s1
(
const
float
*
input
,
...
@@ -400,5 +362,44 @@ void WinoGradConv3x3s1(const float *input,
...
@@ -400,5 +362,44 @@ void WinoGradConv3x3s1(const float *input,
delete
[]
transformed_output
;
delete
[]
transformed_output
;
}
}
void
ConvRef3x3s1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
float
*
output
)
{
index_t
out_height
=
in_height
-
2
;
index_t
out_width
=
in_width
-
2
;
#pragma omp parallel for collapse(4)
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
for
(
index_t
m
=
0
;
m
<
out_channels
;
++
m
)
{
for
(
index_t
h
=
0
;
h
<
out_height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
++
w
)
{
index_t
out_offset
=
((
b
*
out_channels
+
m
)
*
out_height
+
h
)
*
out_width
+
w
;
output
[
out_offset
]
=
0
;
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
for
(
index_t
kh
=
0
;
kh
<
3
;
++
kh
)
{
for
(
index_t
kw
=
0
;
kw
<
3
;
++
kw
)
{
index_t
ih
=
h
+
kh
;
index_t
iw
=
w
+
kw
;
index_t
in_offset
=
((
b
*
in_channels
+
c
)
*
in_height
+
ih
)
*
in_width
+
iw
;
index_t
filter_offset
=
(((
m
*
in_channels
)
+
c
)
*
3
+
kh
)
*
3
+
kw
;
output
[
out_offset
]
+=
input
[
in_offset
]
*
filter
[
filter_offset
];
}
}
}
}
}
}
}
}
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
mace/kernels/arm/conv_winograd.h
浏览文件 @
db75d542
...
@@ -36,6 +36,15 @@ void WinoGradConv3x3s1(const float *input,
...
@@ -36,6 +36,15 @@ void WinoGradConv3x3s1(const float *input,
bool
is_filter_transformed
,
bool
is_filter_transformed
,
float
*
output
);
float
*
output
);
// Reference 3x3 stride-1 convolution (no padding), declared here so it can be
// called from outside conv_winograd.cc.
// Per the definition shown in this commit: the output spatial size is
// (in_height - 2) x (in_width - 2), input is indexed as
// ((b * in_channels + c) * in_height + h) * in_width + w, and filter as
// ((m * in_channels + c) * 3 + kh) * 3 + kw. `output` is overwritten.
void
ConvRef3x3s1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
float
*
output
);
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
...
...
mace/kernels/gemm.cc
浏览文件 @
db75d542
...
@@ -13,22 +13,6 @@ namespace mace {
...
@@ -13,22 +13,6 @@ namespace mace {
namespace
kernels
{
namespace
kernels
{
namespace
{
namespace
{
void
GemmRef
(
const
float
*
A
,
const
float
*
B
,
const
index_t
height
,
const
index_t
K
,
const
index_t
width
,
float
*
C
)
{
memset
(
C
,
0
,
sizeof
(
float
)
*
height
*
width
);
for
(
int
i
=
0
;
i
<
height
;
++
i
)
{
for
(
int
j
=
0
;
j
<
width
;
++
j
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
C
[
i
*
width
+
j
]
+=
A
[
i
*
K
+
k
]
*
B
[
k
*
width
+
j
];
}
}
}
}
inline
void
GemmBlock
(
const
float
*
A
,
inline
void
GemmBlock
(
const
float
*
A
,
const
float
*
B
,
const
float
*
B
,
const
index_t
height
,
const
index_t
height
,
...
@@ -49,8 +33,8 @@ inline void GemmBlock(const float *A,
...
@@ -49,8 +33,8 @@ inline void GemmBlock(const float *A,
// TODO(liyin): may need implement 883 since RGB
// TODO(liyin): may need implement 883 since RGB
inline
void
Gemm884
(
const
float
*
a_ptr
,
inline
void
Gemm884
(
const
float
*
a_ptr
,
const
float
*
b_ptr
,
const
float
*
b_ptr
,
index_t
stride_w
,
index_t
stride_k
,
index_t
stride_k
,
index_t
stride_w
,
float
*
c_ptr
)
{
float
*
c_ptr
)
{
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t
a0
,
a1
,
a2
,
a3
,
a4
,
a5
,
a6
,
a7
,
a8
,
a9
,
a10
,
a11
,
a12
,
a13
,
a14
,
float32x4_t
a0
,
a1
,
a2
,
a3
,
a4
,
a5
,
a6
,
a7
,
a8
,
a9
,
a10
,
a11
,
a12
,
a13
,
a14
,
...
@@ -136,29 +120,300 @@ inline void GemmTile(const float *A,
...
@@ -136,29 +120,300 @@ inline void GemmTile(const float *A,
float
*
C
)
{
float
*
C
)
{
index_t
h
,
w
,
k
;
index_t
h
,
w
,
k
;
for
(
h
=
0
;
h
+
7
<
height
;
h
+=
8
)
{
for
(
h
=
0
;
h
+
7
<
height
;
h
+=
8
)
{
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
for
(
k
=
0
;
k
+
7
<
K
;
k
+=
8
)
{
for
(
k
=
0
;
k
+
7
<
K
;
k
+=
8
)
{
const
float
*
a_ptr
=
A
+
(
h
*
stride_k
+
k
);
const
float
*
a_ptr
=
A
+
(
h
*
stride_k
+
k
);
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#ifdef __clang__
int
nw
=
width
>>
2
;
if
(
nw
>
0
)
{
// load A
float32x4_t
a0
,
a1
,
a2
,
a3
,
a4
,
a5
,
a6
,
a7
,
a8
,
a9
,
a10
,
a11
,
a12
,
a13
,
a14
,
a15
;
a0
=
vld1q_f32
(
a_ptr
);
a1
=
vld1q_f32
(
a_ptr
+
4
);
a2
=
vld1q_f32
(
a_ptr
+
1
*
stride_k
);
a3
=
vld1q_f32
(
a_ptr
+
1
*
stride_k
+
4
);
a4
=
vld1q_f32
(
a_ptr
+
2
*
stride_k
);
a5
=
vld1q_f32
(
a_ptr
+
2
*
stride_k
+
4
);
a6
=
vld1q_f32
(
a_ptr
+
3
*
stride_k
);
a7
=
vld1q_f32
(
a_ptr
+
3
*
stride_k
+
4
);
a8
=
vld1q_f32
(
a_ptr
+
4
*
stride_k
);
a9
=
vld1q_f32
(
a_ptr
+
4
*
stride_k
+
4
);
a10
=
vld1q_f32
(
a_ptr
+
5
*
stride_k
);
a11
=
vld1q_f32
(
a_ptr
+
5
*
stride_k
+
4
);
a12
=
vld1q_f32
(
a_ptr
+
6
*
stride_k
);
a13
=
vld1q_f32
(
a_ptr
+
6
*
stride_k
+
4
);
a14
=
vld1q_f32
(
a_ptr
+
7
*
stride_k
);
a15
=
vld1q_f32
(
a_ptr
+
7
*
stride_k
+
4
);
const
float
*
b_ptr0
=
B
+
k
*
stride_w
;
const
float
*
b_ptr1
=
B
+
(
k
+
1
)
*
stride_w
;
const
float
*
b_ptr2
=
B
+
(
k
+
2
)
*
stride_w
;
const
float
*
b_ptr3
=
B
+
(
k
+
3
)
*
stride_w
;
const
float
*
b_ptr4
=
B
+
(
k
+
4
)
*
stride_w
;
const
float
*
b_ptr5
=
B
+
(
k
+
5
)
*
stride_w
;
const
float
*
b_ptr6
=
B
+
(
k
+
6
)
*
stride_w
;
const
float
*
b_ptr7
=
B
+
(
k
+
7
)
*
stride_w
;
float
*
c_ptr0
=
C
+
h
*
stride_w
;
float
*
c_ptr1
=
C
+
(
h
+
1
)
*
stride_w
;
float
*
c_ptr2
=
C
+
(
h
+
2
)
*
stride_w
;
float
*
c_ptr3
=
C
+
(
h
+
3
)
*
stride_w
;
float
*
c_ptr4
=
C
+
(
h
+
4
)
*
stride_w
;
float
*
c_ptr5
=
C
+
(
h
+
5
)
*
stride_w
;
float
*
c_ptr6
=
C
+
(
h
+
6
)
*
stride_w
;
float
*
c_ptr7
=
C
+
(
h
+
7
)
*
stride_w
;
asm
volatile
(
"0:
\n
"
"prfm pldl1keep, [%1, #128]
\n
"
"ld1 {v24.4s}, [%1]
\n
"
// load b: 0-7
"prfm pldl1keep, [%9, #128]
\n
"
"ld1 {v16.4s}, [%9], #16
\n
"
"prfm pldl1keep, [%10, #128]
\n
"
"ld1 {v17.4s}, [%10], #16
\n
"
"prfm pldl1keep, [%11, #128]
\n
"
"ld1 {v18.4s}, [%11], #16
\n
"
"prfm pldl1keep, [%12, #128]
\n
"
"ld1 {v19.4s}, [%12], #16
\n
"
"prfm pldl1keep, [%2, #128]
\n
"
"ld1 {v25.4s}, [%2]
\n
"
"prfm pldl1keep, [%13, #128]
\n
"
"ld1 {v20.4s}, [%13], #16
\n
"
"prfm pldl1keep, [%14, #128]
\n
"
"ld1 {v21.4s}, [%14], #16
\n
"
"prfm pldl1keep, [%15, #128]
\n
"
"ld1 {v22.4s}, [%15], #16
\n
"
"prfm pldl1keep, [%16, #128]
\n
"
"ld1 {v23.4s}, [%16], #16
\n
"
"prfm pldl1keep, [%3, #128]
\n
"
"ld1 {v26.4s}, [%3]
\n
"
"fmla v24.4s, v16.4s, %34.s[0]
\n
"
"fmla v24.4s, v17.4s, %34.s[1]
\n
"
"fmla v24.4s, v18.4s, %34.s[2]
\n
"
"fmla v24.4s, v19.4s, %34.s[3]
\n
"
"fmla v24.4s, v20.4s, %35.s[0]
\n
"
"fmla v24.4s, v21.4s, %35.s[1]
\n
"
"fmla v24.4s, v22.4s, %35.s[2]
\n
"
"fmla v24.4s, v23.4s, %35.s[3]
\n
"
"st1 {v24.4s}, [%1], #16
\n
"
"fmla v25.4s, v16.4s, %36.s[0]
\n
"
"fmla v25.4s, v17.4s, %36.s[1]
\n
"
"fmla v25.4s, v18.4s, %36.s[2]
\n
"
"fmla v25.4s, v19.4s, %36.s[3]
\n
"
"fmla v25.4s, v20.4s, %37.s[0]
\n
"
"fmla v25.4s, v21.4s, %37.s[1]
\n
"
"fmla v25.4s, v22.4s, %37.s[2]
\n
"
"fmla v25.4s, v23.4s, %37.s[3]
\n
"
"prfm pldl1keep, [%4, #128]
\n
"
"ld1 {v24.4s}, [%4]
\n
"
"st1 {v25.4s}, [%2], #16
\n
"
"fmla v26.4s, v16.4s, %38.s[0]
\n
"
"fmla v26.4s, v17.4s, %38.s[1]
\n
"
"fmla v26.4s, v18.4s, %38.s[2]
\n
"
"fmla v26.4s, v19.4s, %38.s[3]
\n
"
"fmla v26.4s, v20.4s, %39.s[0]
\n
"
"fmla v26.4s, v21.4s, %39.s[1]
\n
"
"fmla v26.4s, v22.4s, %39.s[2]
\n
"
"fmla v26.4s, v23.4s, %39.s[3]
\n
"
"prfm pldl1keep, [%5, #128]
\n
"
"ld1 {v25.4s}, [%5]
\n
"
"st1 {v26.4s}, [%3], #16
\n
"
"fmla v24.4s, v16.4s, %40.s[0]
\n
"
"fmla v24.4s, v17.4s, %40.s[1]
\n
"
"fmla v24.4s, v18.4s, %40.s[2]
\n
"
"fmla v24.4s, v19.4s, %40.s[3]
\n
"
"fmla v24.4s, v20.4s, %41.s[0]
\n
"
"fmla v24.4s, v21.4s, %41.s[1]
\n
"
"fmla v24.4s, v22.4s, %41.s[2]
\n
"
"fmla v24.4s, v23.4s, %41.s[3]
\n
"
"prfm pldl1keep, [%6, #128]
\n
"
"ld1 {v26.4s}, [%6]
\n
"
"st1 {v24.4s}, [%4], #16
\n
"
"fmla v25.4s, v16.4s, %42.s[0]
\n
"
"fmla v25.4s, v17.4s, %42.s[1]
\n
"
"fmla v25.4s, v18.4s, %42.s[2]
\n
"
"fmla v25.4s, v19.4s, %42.s[3]
\n
"
"fmla v25.4s, v20.4s, %43.s[0]
\n
"
"fmla v25.4s, v21.4s, %43.s[1]
\n
"
"fmla v25.4s, v22.4s, %43.s[2]
\n
"
"fmla v25.4s, v23.4s, %43.s[3]
\n
"
"prfm pldl1keep, [%7, #128]
\n
"
"ld1 {v24.4s}, [%7]
\n
"
"st1 {v25.4s}, [%5], #16
\n
"
"fmla v26.4s, v16.4s, %44.s[0]
\n
"
"fmla v26.4s, v17.4s, %44.s[1]
\n
"
"fmla v26.4s, v18.4s, %44.s[2]
\n
"
"fmla v26.4s, v19.4s, %44.s[3]
\n
"
"fmla v26.4s, v20.4s, %45.s[0]
\n
"
"fmla v26.4s, v21.4s, %45.s[1]
\n
"
"fmla v26.4s, v22.4s, %45.s[2]
\n
"
"fmla v26.4s, v23.4s, %45.s[3]
\n
"
"prfm pldl1keep, [%8, #128]
\n
"
"ld1 {v25.4s}, [%8]
\n
"
"st1 {v26.4s}, [%6], #16
\n
"
"fmla v24.4s, v16.4s, %46.s[0]
\n
"
"fmla v24.4s, v17.4s, %46.s[1]
\n
"
"fmla v24.4s, v18.4s, %46.s[2]
\n
"
"fmla v24.4s, v19.4s, %46.s[3]
\n
"
"fmla v24.4s, v20.4s, %47.s[0]
\n
"
"fmla v24.4s, v21.4s, %47.s[1]
\n
"
"fmla v24.4s, v22.4s, %47.s[2]
\n
"
"fmla v24.4s, v23.4s, %47.s[3]
\n
"
"st1 {v24.4s}, [%7], #16
\n
"
"fmla v25.4s, v16.4s, %48.s[0]
\n
"
"fmla v25.4s, v17.4s, %48.s[1]
\n
"
"fmla v25.4s, v18.4s, %48.s[2]
\n
"
"fmla v25.4s, v19.4s, %48.s[3]
\n
"
"fmla v25.4s, v20.4s, %49.s[0]
\n
"
"fmla v25.4s, v21.4s, %49.s[1]
\n
"
"fmla v25.4s, v22.4s, %49.s[2]
\n
"
"fmla v25.4s, v23.4s, %49.s[3]
\n
"
"st1 {v25.4s}, [%8], #16
\n
"
"subs %w0, %w0, #1
\n
"
"bne 0b
\n
"
:
"=r"
(
nw
),
// 0
"=r"
(
c_ptr0
),
// 1
"=r"
(
c_ptr1
),
// 2
"=r"
(
c_ptr2
),
// 3
"=r"
(
c_ptr3
),
// 4
"=r"
(
c_ptr4
),
// 5
"=r"
(
c_ptr5
),
// 6
"=r"
(
c_ptr6
),
// 7
"=r"
(
c_ptr7
),
// 8
"=r"
(
b_ptr0
),
// 9
"=r"
(
b_ptr1
),
// 10
"=r"
(
b_ptr2
),
// 11
"=r"
(
b_ptr3
),
// 12
"=r"
(
b_ptr4
),
// 13
"=r"
(
b_ptr5
),
// 14
"=r"
(
b_ptr6
),
// 15
"=r"
(
b_ptr7
)
// 16
:
"0"
(
nw
),
// 17
"1"
(
c_ptr0
),
// 18
"2"
(
c_ptr1
),
// 19
"3"
(
c_ptr2
),
// 20
"4"
(
c_ptr3
),
// 21
"5"
(
c_ptr4
),
// 22
"6"
(
c_ptr5
),
// 23
"7"
(
c_ptr6
),
// 24
"8"
(
c_ptr7
),
// 25
"9"
(
b_ptr0
),
// 26
"10"
(
b_ptr1
),
// 27
"11"
(
b_ptr2
),
// 28
"12"
(
b_ptr3
),
// 29
"13"
(
b_ptr4
),
// 30
"14"
(
b_ptr5
),
// 31
"15"
(
b_ptr6
),
// 32
"16"
(
b_ptr7
),
// 33
"w"
(
a0
),
// 34
"w"
(
a1
),
// 35
"w"
(
a2
),
// 36
"w"
(
a3
),
// 37
"w"
(
a4
),
// 38
"w"
(
a5
),
// 39
"w"
(
a6
),
// 40
"w"
(
a7
),
// 41
"w"
(
a8
),
// 42
"w"
(
a9
),
// 43
"w"
(
a10
),
// 44
"w"
(
a11
),
// 45
"w"
(
a12
),
// 46
"w"
(
a13
),
// 47
"w"
(
a14
),
// 48
"w"
(
a15
)
// 49
:
"cc"
,
"memory"
,
"v16"
,
"v17"
,
"v18"
,
"v19"
,
"v20"
,
"v21"
,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
);
w
=
(
width
>>
2
)
<<
2
;
}
#else // gcc
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_w
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_w
+
w
);
Gemm884
(
a_ptr
,
b_ptr
,
stride_k
,
stride_w
,
c_ptr
);
}
#endif
#else
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_w
+
w
);
const
float
*
b_ptr
=
B
+
(
k
*
stride_w
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_w
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_w
+
w
);
Gemm
884
(
a_ptr
,
b_ptr
,
stride_w
,
stride_k
,
c_ptr
);
Gemm
Block
(
a_ptr
,
b_ptr
,
8
,
8
,
4
,
stride_k
,
stride_w
,
c_ptr
);
}
}
if
(
k
<
K
)
{
#endif
if
(
w
<
width
)
{
const
float
*
a_ptr
=
A
+
(
h
*
stride_k
+
k
);
const
float
*
a_ptr
=
A
+
(
h
*
stride_k
+
k
);
const
float
*
b_ptr
=
B
+
(
k
*
stride_w
+
w
);
const
float
*
b_ptr
=
B
+
(
k
*
stride_w
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_w
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_w
+
w
);
GemmBlock
(
a_ptr
,
b_ptr
,
8
,
K
-
k
,
4
,
stride_k
,
stride_w
,
c_ptr
);
GemmBlock
(
a_ptr
,
b_ptr
,
8
,
8
,
width
-
w
,
stride_k
,
stride_w
,
c_ptr
);
}
}
}
}
if
(
w
<
width
)
{
if
(
k
<
K
)
{
const
float
*
a_ptr
=
A
+
h
*
stride_k
;
const
float
*
a_ptr
=
A
+
(
h
*
stride_k
+
k
)
;
const
float
*
b_ptr
=
B
+
w
;
const
float
*
b_ptr
=
B
+
k
*
stride_
w
;
float
*
c_ptr
=
C
+
(
h
*
stride_w
+
w
)
;
float
*
c_ptr
=
C
+
h
*
stride_w
;
GemmBlock
(
a_ptr
,
GemmBlock
(
a_ptr
,
b_ptr
,
b_ptr
,
8
,
8
,
K
,
K
-
k
,
width
-
w
,
width
,
stride_k
,
stride_k
,
stride_w
,
stride_w
,
c_ptr
);
c_ptr
);
...
@@ -243,5 +498,21 @@ void Gemm(const float *A,
...
@@ -243,5 +498,21 @@ void Gemm(const float *A,
}
// n
}
// n
}
}
void
GemmRef
(
const
float
*
A
,
const
float
*
B
,
const
index_t
height
,
const
index_t
K
,
const
index_t
width
,
float
*
C
)
{
memset
(
C
,
0
,
sizeof
(
float
)
*
height
*
width
);
for
(
int
i
=
0
;
i
<
height
;
++
i
)
{
for
(
int
j
=
0
;
j
<
width
;
++
j
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
C
[
i
*
width
+
j
]
+=
A
[
i
*
K
+
k
]
*
B
[
k
*
width
+
j
];
}
}
}
}
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
mace/kernels/gemm.h
浏览文件 @
db75d542
...
@@ -22,6 +22,13 @@ void Gemm(const float *A,
...
@@ -22,6 +22,13 @@ void Gemm(const float *A,
const
index_t
width
,
const
index_t
width
,
float
*
C
);
float
*
C
);
// Reference (naive triple-loop) matrix multiply, exposed for use by tests.
// Per the definition shown in this commit: C is zero-filled and then
// C[i * width + j] += A[i * K + k] * B[k * width + j] for all i < height,
// j < width, k < K.
void
GemmRef
(
const
float
*
A
,
const
float
*
B
,
const
index_t
height
,
const
index_t
K
,
const
index_t
width
,
float
*
C
);
}
// namespace kernels
}
// namespace kernels
}
// namespace mace
}
// namespace mace
...
...
mace/kernels/gemm_test.cc
浏览文件 @
db75d542
...
@@ -31,7 +31,7 @@ TEST(GEMMTest, gemm) {
...
@@ -31,7 +31,7 @@ TEST(GEMMTest, gemm) {
[
&
gen
,
&
nd
]
{
[
&
gen
,
&
nd
]
{
return
nd
(
gen
);
return
nd
(
gen
);
});
});
kernels
::
Gemm
(
A
,
B
,
N
,
K
,
M
,
C
);
kernels
::
Gemm
(
A
,
B
,
1
,
N
,
K
,
M
,
C
);
kernels
::
GemmRef
(
A
,
B
,
N
,
K
,
M
,
C_ref
);
kernels
::
GemmRef
(
A
,
B
,
N
,
K
,
M
,
C_ref
);
for
(
int
i
=
0
;
i
<
N
*
M
;
++
i
)
{
for
(
int
i
=
0
;
i
<
N
*
M
;
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录