Xiaomi / Mace

Commit 47bff05a
Authored Apr 17, 2018 by 吴承辉

Merge branch 'armv7' into 'master'

GEMM Neon v7

See merge request !385

Parents: 680f8b42, 81a2ab7c
Showing 4 changed files with 160 additions and 29 deletions (+160 -29):

mace/kernels/BUILD                       +2   -2
mace/kernels/arm/conv_winograd_test.cc   +24  -13
mace/kernels/gemm.cc                     +132 -12
mace/ops/BUILD                           +2   -2
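Taken together, the four diffs below wire an ARMv7 NEON code path into the existing GEMM kernel: the two BUILD files pass the armv7 compiler flags, the Winograd test is ported to the Tensor API, and gemm.cc gains a 4x4 NEON micro-kernel next to the existing AArch64 8x8 one, with a scalar fallback when NEON is disabled. The compile-time layering can be pictured roughly as follows (an illustrative sketch only; the real GemmTile nests #if/#else blocks around its loops rather than using a single #elif chain):

// Illustrative sketch of the three GemmTile code paths after this merge.
// Not the literal MACE source; see the gemm.cc diff below for the real layout.
#include <cstdio>

int main() {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
  std::printf("AArch64: 8x8 row panels, Gemm884 assembly/intrinsics micro-kernel\n");
#elif defined(MACE_ENABLE_NEON)
  std::printf("ARMv7: 4x4 row panels, vmlaq_lane_f32 intrinsics micro-kernel\n");
#else
  std::printf("Scalar fallback: GemmBlock over the whole tile\n");
#endif
  return 0;
}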
mace/kernels/BUILD  (+2 -2)

@@ -7,7 +7,7 @@ package(
 licenses(["notice"])  # Apache 2.0
-load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled")
+load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7")
 cc_library(
     name = "kernels",
@@ -28,7 +28,7 @@ cc_library(
         "opencl/*.h",
         "arm/*.h",
     ]),
-    copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
+    copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]),
     linkopts = if_android(["-lm"]),
     deps = [
         "//mace/core",
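The new copts matter because <arm_neon.h> and the vmlaq_* intrinsics used in gemm.cc only compile for armeabi-v7a when the target is built with -mfpu=neon -mfloat-abi=softfp, and the kernels keep all intrinsics behind the MACE_ENABLE_NEON define. A minimal sketch of that guard pattern (illustrative, not copied from MACE):

// Minimal example of the compile-time guard this BUILD change enables.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>

// NEON version: multiply-accumulate four floats at once, c[i] += a[i] * b[i].
void MulAdd4(const float *a, const float *b, float *c) {
  float32x4_t va = vld1q_f32(a);
  float32x4_t vb = vld1q_f32(b);
  float32x4_t vc = vld1q_f32(c);
  vc = vmlaq_f32(vc, va, vb);
  vst1q_f32(c, vc);
}
#else
// Portable fallback when NEON is not enabled.
void MulAdd4(const float *a, const float *b, float *c) {
  for (int i = 0; i < 4; ++i) c[i] += a[i] * b[i];
}
#endif

Without the armv7 flags, the NEON header would not compile on armeabi-v7a targets, so the scalar branch remains the safe default.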
mace/kernels/arm/conv_winograd_test.cc  (+24 -13)

@@ -9,6 +9,7 @@
 #include "mace/kernels/arm/conv_winograd.h"
 #include "mace/core/types.h"
+#include "mace/core/tensor.h"
 namespace mace {
 namespace kernels {
@@ -22,45 +23,55 @@ TEST(ConvWinogradTest, winograd) {
   index_t out_height = in_height - 2;
   index_t out_width = in_width - 2;
-  index_t input_size = batch * in_channels * in_height * out_height;
+  index_t input_size = batch * in_channels * in_height * in_width;
   index_t filter_size = 3 * 3 * in_channels * out_channels;
   index_t output_size = batch * out_channels * out_height * out_width;
-  std::unique_ptr<float[]> input_data(new float[input_size]);
-  std::unique_ptr<float[]> filter_data(new float[filter_size]);
-  std::unique_ptr<float[]> output_data(new float[output_size]);
-  std::unique_ptr<float[]> output_data_ref(new float[output_size]);
+  Tensor input;
+  Tensor filter;
+  Tensor output;
+  Tensor output_ref;
+  input.Resize({batch, in_channels, in_height, in_width});
+  filter.Resize({out_channels, in_channels, 3, 3});
+  output.Resize({batch, out_channels, out_height, out_width});
+  output_ref.Resize({batch, out_channels, out_height, out_width});
+  float *input_data = input.mutable_data<float>();
+  float *filter_data = filter.mutable_data<float>();
+  float *output_data = output.mutable_data<float>();
+  float *output_data_ref = output.mutable_data<float>();
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
-  std::generate(input_data.get(), input_data.get() + input_size,
+  std::generate(input_data, input_data + input_size,
                 [&gen, &nd] {
                   return std::max(-1.0f, std::min(1.0f, nd(gen)));
                 });
-  std::generate(filter_data.get(), filter_data.get() + filter_size,
+  std::generate(filter_data, filter_data + filter_size,
                 [&gen, &nd] {
                   return std::max(-1.0f, std::min(1.0f, nd(gen)));
                 });
-  kernels::ConvRef3x3s1(input_data.get(),
-                        filter_data.get(),
+  kernels::ConvRef3x3s1(input_data,
+                        filter_data,
                         batch,
                         in_height,
                         in_width,
                         in_channels,
                         out_channels,
-                        output_data_ref.get());
+                        output_data_ref);
-  kernels::WinoGradConv3x3s1(input_data.get(),
-                             filter_data.get(),
+  kernels::WinoGradConv3x3s1(input_data,
+                             filter_data,
                              batch,
                              in_height,
                              in_width,
                              in_channels,
                              out_channels,
                              6,
-                             output_data.get());
+                             output_data);
   // test
   for (index_t i = 0; i < output_size; ++i) {
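The hunk is truncated at the verification loop. In a gtest-based test like this one, the tail presumably compares the Winograd output element-wise against the direct 3x3 reference, along the lines of this hypothetical sketch (EXPECT_NEAR and the tolerance value are assumptions, not lines from the file):

// Hypothetical continuation of the truncated loop above: verify the Winograd
// result against the reference convolution within a tolerance.
for (index_t i = 0; i < output_size; ++i) {
  EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1);
}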
mace/kernels/gemm.cc  (+132 -12)

@@ -5,10 +5,16 @@
 #include <math.h>
 #include <algorithm>
+#if defined(MACE_ENABLE_NEON)
+#include <arm_neon.h>
+#endif
 #include "mace/kernels/gemm.h"
 #include "mace/utils/utils.h"
 #include "mace/utils/logging.h"
 namespace mace {
 namespace kernels {
@@ -119,12 +125,11 @@ inline void GemmTile(const float *A,
                      const index_t stride_w,
                      float *C) {
   index_t h, w, k;
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
   for (h = 0; h + 7 < height; h += 8) {
     for (k = 0; k + 7 < K; k += 8) {
       const float *a_ptr = A + (h * stride_k + k);
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
 #ifdef __clang__
       int nw = width >> 2;
       if (nw > 0) {
@@ -388,21 +393,132 @@ inline void GemmTile(const float *A,
         float *c_ptr = C + (h * stride_w + w);
         Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
       }
-#endif
+#endif  // clang
       if (w < width) {
+#else
       const float *a_ptr = A + (h * stride_k + k);
       for (w = 0; w + 3 < width; w += 4) {
         const float *b_ptr = B + (k * stride_w + w);
         float *c_ptr = C + (h * stride_w + w);
         GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
         GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
       }
 #endif
     }
     if (k < K) {
       const float *a_ptr = A + (h * stride_k + k);
       const float *b_ptr = B + k * stride_w;
       float *c_ptr = C + h * stride_w;
       GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_k, stride_w, c_ptr);
     }
   }
   if (h < height) {
     // TODO(liyin): may use Gemm444
     const float *a_ptr = A + (h * stride_k);
     const float *b_ptr = B;
     float *c_ptr = C + h * stride_w;
     GemmBlock(a_ptr, b_ptr, height - h, K, width, stride_k, stride_w, c_ptr);
   }
+#else
+#if defined(MACE_ENABLE_NEON)  // armv7
+  for (h = 0; h + 3 < height; h += 4) {
+    for (k = 0; k + 3 < K; k += 4) {
+      const float *a_ptr = A + (h * stride_k + k);
+      int nw = width >> 2;
+      if (nw > 0) {
+        // load A
+        float32x2_t a00, a01, a10, a11, a20, a21, a30, a31;
+        a00 = vld1_f32(a_ptr);
+        a01 = vld1_f32(a_ptr + 2);
+        a10 = vld1_f32(a_ptr + 1 * stride_k);
+        a11 = vld1_f32(a_ptr + 1 * stride_k + 2);
+        a20 = vld1_f32(a_ptr + 2 * stride_k);
+        a21 = vld1_f32(a_ptr + 2 * stride_k + 2);
+        a30 = vld1_f32(a_ptr + 3 * stride_k);
+        a31 = vld1_f32(a_ptr + 3 * stride_k + 2);
+        const float *b_ptr0 = B + k * stride_w;
+        const float *b_ptr1 = B + (k + 1) * stride_w;
+        const float *b_ptr2 = B + (k + 2) * stride_w;
+        const float *b_ptr3 = B + (k + 3) * stride_w;
+        float *c_ptr0 = C + h * stride_w;
+        float *c_ptr1 = C + (h + 1) * stride_w;
+        float *c_ptr2 = C + (h + 2) * stride_w;
+        float *c_ptr3 = C + (h + 3) * stride_w;
+        // TODO(liyin): asm v7 prefetch and load optimization
+        while (nw--) {
+          float32x4_t b0, b1, b2, b3;
+          float32x4_t c0;
+          c0 = vld1q_f32(c_ptr0);
+          b0 = vld1q_f32(b_ptr0);
+          b1 = vld1q_f32(b_ptr1);
+          b2 = vld1q_f32(b_ptr2);
+          b3 = vld1q_f32(b_ptr3);
+          c0 = vmlaq_lane_f32(c0, b0, a00, 0);
+          c0 = vmlaq_lane_f32(c0, b1, a00, 1);
+          c0 = vmlaq_lane_f32(c0, b2, a01, 0);
+          c0 = vmlaq_lane_f32(c0, b3, a01, 1);
+          vst1q_f32(c_ptr0, c0);
+          c0 = vld1q_f32(c_ptr1);
+          c0 = vmlaq_lane_f32(c0, b0, a10, 0);
+          c0 = vmlaq_lane_f32(c0, b1, a10, 1);
+          c0 = vmlaq_lane_f32(c0, b2, a11, 0);
+          c0 = vmlaq_lane_f32(c0, b3, a11, 1);
+          vst1q_f32(c_ptr1, c0);
+          c0 = vld1q_f32(c_ptr2);
+          c0 = vmlaq_lane_f32(c0, b0, a20, 0);
+          c0 = vmlaq_lane_f32(c0, b1, a20, 1);
+          c0 = vmlaq_lane_f32(c0, b2, a21, 0);
+          c0 = vmlaq_lane_f32(c0, b3, a21, 1);
+          vst1q_f32(c_ptr2, c0);
+          c0 = vld1q_f32(c_ptr3);
+          c0 = vmlaq_lane_f32(c0, b0, a30, 0);
+          c0 = vmlaq_lane_f32(c0, b1, a30, 1);
+          c0 = vmlaq_lane_f32(c0, b2, a31, 0);
+          c0 = vmlaq_lane_f32(c0, b3, a31, 1);
+          vst1q_f32(c_ptr3, c0);
+          b_ptr0 += 4;
+          b_ptr1 += 4;
+          b_ptr2 += 4;
+          b_ptr3 += 4;
+          c_ptr0 += 4;
+          c_ptr1 += 4;
+          c_ptr2 += 4;
+          c_ptr3 += 4;
+        }
+        w = (width >> 2) << 2;
+      }
       if (w < width) {
         const float *a_ptr = A + (h * stride_k + k);
         const float *b_ptr = B + (k * stride_w + w);
         float *c_ptr = C + (h * stride_w + w);
-        GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
+        GemmBlock(a_ptr, b_ptr, 4, 4, width - w, stride_k, stride_w, c_ptr);
       }
     }
     if (k < K) {
@@ -411,7 +527,7 @@ inline void GemmTile(const float *A,
       float *c_ptr = C + h * stride_w;
       GemmBlock(a_ptr,
                 b_ptr,
-                8,
+                4,
                 K - k,
                 width,
                 stride_k,
@@ -420,7 +536,6 @@ inline void GemmTile(const float *A,
     }
   }
   if (h < height) {
-    // TODO(liyin): may use Gemm444
     const float *a_ptr = A + (h * stride_k);
     const float *b_ptr = B;
     float *c_ptr = C + h * stride_w;
@@ -433,6 +548,11 @@ inline void GemmTile(const float *A,
               stride_w,
               c_ptr);
   }
+#else  // cpu
+  GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
+#endif  // armv7
+#endif  // aarch64
 }
 }  // namespace
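For readers new to the intrinsics in the ARMv7 block: vld1_f32 loads a pair of floats, vld1q_f32 and vst1q_f32 load and store four floats, and vmlaq_lane_f32(c, b, a, lane) accumulates c += b * a[lane]. Each pass of the while (nw--) loop therefore updates a 4-row by 4-column patch of C (rows h..h+3, columns w..w+3) with a 4x4 block of A times a 4x4 block of B: 16 vmlaq_lane_f32 calls, each doing 4 multiply-accumulates, for 64 FMAs per loaded block of A. A scalar reference of that update under the same stride conventions (a sketch for clarity, not MACE code):

// Scalar reference for one iteration of the ARMv7 inner loop above.
// a: 4x4 block of A (row stride stride_k), b: 4x4 block of B (row stride stride_w),
// c: 4x4 block of C (row stride stride_w); computes c += a * b.
void MicroKernel4x4x4(const float *a, const float *b, float *c,
                      int stride_k, int stride_w) {
  for (int i = 0; i < 4; ++i) {        // rows of A and C (h + i)
    for (int j = 0; j < 4; ++j) {      // columns of B and C (w + j)
      float acc = c[i * stride_w + j];
      for (int p = 0; p < 4; ++p) {    // reduction dimension (k + p)
        acc += a[i * stride_k + p] * b[p * stride_w + j];
      }
      c[i * stride_w + j] = acc;
    }
  }
}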
mace/ops/BUILD  (+2 -2)

@@ -7,7 +7,7 @@ package(
 licenses(["notice"])  # Apache 2.0
-load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled")
+load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7")
 cc_library(
     name = "test",
@@ -34,7 +34,7 @@ cc_library(
         ["*.h"],
         exclude = ["ops_test_util.h"],
     ),
-    copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
+    copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]),
     deps = [
         "//mace/kernels",
     ],