Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
a376a1b8
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a376a1b8
编写于
6月 12, 2018
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of v9.git.n.xiaomi.com:deep-computing/mace into add_shared_lib
上级
07787144
6a154ca8
变更
13
显示空白变更内容
内联
并排
Showing
13 changed files
with
622 additions
and
116 deletions
+622
-116
mace/core/mace.cc
mace/core/mace.cc
+30
-1
mace/core/runtime/opencl/opencl_wrapper.cc
mace/core/runtime/opencl/opencl_wrapper.cc
+3
-1
mace/kernels/gemm.cc
mace/kernels/gemm.cc
+474
-110
mace/ops/shape_test.cc
mace/ops/shape_test.cc
+3
-1
mace/ops/strided_slice_test.cc
mace/ops/strided_slice_test.cc
+10
-3
mace/python/tools/converter_tool/base_converter.py
mace/python/tools/converter_tool/base_converter.py
+2
-0
mace/python/tools/converter_tool/tensorflow_converter.py
mace/python/tools/converter_tool/tensorflow_converter.py
+2
-0
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+18
-0
mace/test/BUILD
mace/test/BUILD
+21
-0
mace/test/mace_api_exception_test.cc
mace/test/mace_api_exception_test.cc
+40
-0
mace/test/mace_api_mt_test.cc
mace/test/mace_api_mt_test.cc
+4
-0
mace/test/mace_api_test.cc
mace/test/mace_api_test.cc
+5
-0
mace/utils/utils.h
mace/utils/utils.h
+10
-0
未找到文件。
mace/core/mace.cc
浏览文件 @
a376a1b8
...
...
@@ -106,6 +106,8 @@ class MaceEngine::Impl {
DeviceType
device_type_
;
std
::
unique_ptr
<
Workspace
>
ws_
;
std
::
unique_ptr
<
NetBase
>
net_
;
std
::
map
<
std
::
string
,
mace
::
InputInfo
>
input_info_map_
;
std
::
map
<
std
::
string
,
mace
::
OutputInfo
>
output_info_map_
;
#ifdef MACE_ENABLE_HEXAGON
std
::
unique_ptr
<
HexagonControlWrapper
>
hexagon_controller_
;
#endif
...
...
@@ -131,12 +133,29 @@ MaceStatus MaceEngine::Impl::Init(
const
std
::
vector
<
std
::
string
>
&
output_nodes
,
const
unsigned
char
*
model_data
)
{
LOG
(
INFO
)
<<
"Initializing MaceEngine"
;
// Get input and output information.
for
(
auto
&
input_info
:
net_def
->
input_info
())
{
input_info_map_
[
input_info
.
name
()]
=
input_info
;
}
for
(
auto
&
output_info
:
net_def
->
output_info
())
{
output_info_map_
[
output_info
.
name
()]
=
output_info
;
}
// Set storage path for internal usage
for
(
auto
input_name
:
input_nodes
)
{
if
(
input_info_map_
.
find
(
input_name
)
==
input_info_map_
.
end
())
{
LOG
(
FATAL
)
<<
"'"
<<
input_name
<<
"' is not belong to model's inputs: "
<<
MakeString
(
MapKeys
(
input_info_map_
));
}
ws_
->
CreateTensor
(
MakeString
(
"mace_input_node_"
,
input_name
),
GetDeviceAllocator
(
device_type_
),
DT_FLOAT
);
}
for
(
auto
output_name
:
output_nodes
)
{
if
(
output_info_map_
.
find
(
output_name
)
==
output_info_map_
.
end
())
{
LOG
(
FATAL
)
<<
"'"
<<
output_name
<<
"' is not belong to model's outputs "
<<
MakeString
(
MapKeys
(
output_info_map_
));
}
ws_
->
CreateTensor
(
MakeString
(
"mace_output_node_"
,
output_name
),
GetDeviceAllocator
(
device_type_
),
DT_FLOAT
);
}
...
...
@@ -193,6 +212,11 @@ MaceStatus MaceEngine::Impl::Run(
std
::
vector
<
Tensor
*>
input_tensors
;
std
::
vector
<
Tensor
*>
output_tensors
;
for
(
auto
&
input
:
inputs
)
{
if
(
input_info_map_
.
find
(
input
.
first
)
==
input_info_map_
.
end
())
{
LOG
(
FATAL
)
<<
"'"
<<
input
.
first
<<
"' is not belong to model's inputs: "
<<
MakeString
(
MapKeys
(
input_info_map_
));
}
MACE_CHECK
(
input
.
second
.
shape
().
size
()
==
4
,
"The Inputs' shape must be 4-dimension with NHWC format,"
" please use 1 to fill missing dimensions"
);
...
...
@@ -208,6 +232,11 @@ MaceStatus MaceEngine::Impl::Run(
input_tensors
.
push_back
(
input_tensor
);
}
for
(
auto
&
output
:
*
outputs
)
{
if
(
output_info_map_
.
find
(
output
.
first
)
==
output_info_map_
.
end
())
{
LOG
(
FATAL
)
<<
"'"
<<
output
.
first
<<
"' is not belong to model's outputs: "
<<
MakeString
(
MapKeys
(
output_info_map_
));
}
if
(
device_type_
==
DeviceType
::
GPU
)
{
MACE_CHECK
(
output
.
second
.
shape
().
size
()
==
4
,
"The outputs' shape must be 4-dimension with NHWC format,"
...
...
@@ -245,7 +274,7 @@ MaceStatus MaceEngine::Impl::Run(
std
::
multiplies
<
int64_t
>
());
MACE_CHECK
(
!
shape
.
empty
())
<<
"Output's shape must greater than 0"
;
MACE_CHECK
(
shape
==
output
.
second
.
shape
())
<<
"Output shape mis
p
atch: "
<<
"Output shape mis
m
atch: "
<<
MakeString
<
int64_t
>
(
output
.
second
.
shape
())
<<
" != "
<<
MakeString
<
int64_t
>
(
shape
);
std
::
memcpy
(
output
.
second
.
data
().
get
(),
output_tensor
->
data
<
float
>
(),
...
...
mace/core/runtime/opencl/opencl_wrapper.cc
浏览文件 @
a376a1b8
...
...
@@ -281,7 +281,9 @@ bool OpenCLLibraryImpl::Load() {
}
if
(
handle_
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Failed to load OpenCL library"
;
LOG
(
ERROR
)
<<
"Failed to load OpenCL library, "
"please make sure there exist OpenCL library on your device, "
"and your APP have right to access the library."
;
return
false
;
}
...
...
mace/kernels/gemm.cc
浏览文件 @
a376a1b8
...
...
@@ -50,7 +50,7 @@ inline void GemmBlock(const float *A,
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL
(RC, RA, RAN)
\
#define MACE_GEMM_PART_CAL
_8(RC, RA, RAN)
\
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \
...
...
@@ -60,7 +60,7 @@ inline void GemmBlock(const float *A,
c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
#else
#define MACE_GEMM_PART_CAL
(RC, RA, RAN)
\
#define MACE_GEMM_PART_CAL
_8(RC, RA, RAN)
\
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \
...
...
@@ -72,6 +72,283 @@ inline void GemmBlock(const float *A,
#endif
#endif
// MACE_GEMM_PART_CAL_4(RC): accumulate one output row of a K=4 product.
//
// Expands to four multiply-accumulate statements that update c<RC> in place:
//   c<RC> += b0 * a<RC>[0] + b1 * a<RC>[1] + b2 * a<RC>[2] + b3 * a<RC>[3]
// The names are built by token pasting, so the use site must have
// float32x4_t locals called a<RC>, c<RC> and b0..b3 in scope.
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
// AArch64 variant: vfmaq_laneq_f32 takes the lane index (0..3) straight from
// the 128-bit a<RC> register.
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3);
#else
// 32-bit ARM variant: vmlaq_lane_f32 selects the lane from a 64-bit half, so
// a<RC> is split with vget_low_f32 (elements 0, 1) and vget_high_f32
// (elements 2, 3).
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);
#endif
#endif
// Accumulates a 1x4 tile: C[0, 0:4] += A[0, 0:4] * B[0:4, 0:4].
//
// a_ptr/c_ptr point at the first element of the single A/C row, b_ptr at the
// first of four B rows that are stride_b floats apart. Without NEON the work
// is forwarded to GemmBlock with sizes (1, 4, 4).
inline void Gemm144(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  // Only one row of A and C is touched, so their row strides are not needed.
  MACE_UNUSED(stride_a);
  MACE_UNUSED(stride_c);

  const float32x4_t a0 = vld1q_f32(a_ptr);

  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);

  float32x4_t c0 = vld1q_f32(c_ptr);

  // The macro refers to a0, b0..b3 and c0 by name.
  MACE_GEMM_PART_CAL_4(0);

  vst1q_f32(c_ptr, c0);
#else
  GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
// Accumulates a 2x4 tile: C[r, 0:4] += A[r, 0:4] * B[0:4, 0:4] for r = 0..1.
//
// Consecutive rows of A, B and C are stride_a, stride_b and stride_c floats
// apart. Without NEON the work is forwarded to GemmBlock with sizes (2, 4, 4).
inline void Gemm244(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + stride_a);

  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);

  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + stride_c);

  // Each expansion updates c<N> from a<N> and b0..b3 (looked up by name).
  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + stride_c, c1);
#else
  GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
// Accumulates a 3x4 tile: C[r, 0:4] += A[r, 0:4] * B[0:4, 0:4] for r = 0..2.
//
// Consecutive rows of A, B and C are stride_a, stride_b and stride_c floats
// apart. Without NEON the work is forwarded to GemmBlock with sizes (3, 4, 4).
inline void Gemm344(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + stride_a);
  const float32x4_t a2 = vld1q_f32(a_ptr + 2 * stride_a);

  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);

  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + stride_c);
  float32x4_t c2 = vld1q_f32(c_ptr + 2 * stride_c);

  // Each expansion updates c<N> from a<N> and b0..b3 (looked up by name).
  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
  GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
// Accumulates a 4x4 tile: C[r, 0:4] += A[r, 0:4] * B[0:4, 0:4] for r = 0..3.
//
// Consecutive rows of A, B and C are stride_a, stride_b and stride_c floats
// apart. Without NEON the work is forwarded to GemmBlock with sizes (4, 4, 4).
inline void Gemm444(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + stride_a);
  const float32x4_t a2 = vld1q_f32(a_ptr + 2 * stride_a);
  const float32x4_t a3 = vld1q_f32(a_ptr + 3 * stride_a);

  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);

  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + stride_c);
  float32x4_t c2 = vld1q_f32(c_ptr + 2 * stride_c);
  float32x4_t c3 = vld1q_f32(c_ptr + 3 * stride_c);

  // Each expansion updates c<N> from a<N> and b0..b3 (looked up by name).
  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);
  MACE_GEMM_PART_CAL_4(3);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
  GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
// Accumulates a 5x4 tile: C[r, 0:4] += A[r, 0:4] * B[0:4, 0:4] for r = 0..4.
//
// Consecutive rows of A, B and C are stride_a, stride_b and stride_c floats
// apart. Without NEON the work is forwarded to GemmBlock with sizes (5, 4, 4).
inline void Gemm544(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + stride_a);
  const float32x4_t a2 = vld1q_f32(a_ptr + 2 * stride_a);
  const float32x4_t a3 = vld1q_f32(a_ptr + 3 * stride_a);
  const float32x4_t a4 = vld1q_f32(a_ptr + 4 * stride_a);

  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);

  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + stride_c);
  float32x4_t c2 = vld1q_f32(c_ptr + 2 * stride_c);
  float32x4_t c3 = vld1q_f32(c_ptr + 3 * stride_c);
  float32x4_t c4 = vld1q_f32(c_ptr + 4 * stride_c);

  // Each expansion updates c<N> from a<N> and b0..b3 (looked up by name).
  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);
  MACE_GEMM_PART_CAL_4(3);
  MACE_GEMM_PART_CAL_4(4);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
  vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
  GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
// Accumulates a 6x4 tile: C[r, 0:4] += A[r, 0:4] * B[0:4, 0:4] for r = 0..5.
//
// Consecutive rows of A, B and C are stride_a, stride_b and stride_c floats
// apart. Without NEON the work is forwarded to GemmBlock with sizes (6, 4, 4).
inline void Gemm644(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
  const float32x4_t a0 = vld1q_f32(a_ptr);
  const float32x4_t a1 = vld1q_f32(a_ptr + stride_a);
  const float32x4_t a2 = vld1q_f32(a_ptr + 2 * stride_a);
  const float32x4_t a3 = vld1q_f32(a_ptr + 3 * stride_a);
  const float32x4_t a4 = vld1q_f32(a_ptr + 4 * stride_a);
  const float32x4_t a5 = vld1q_f32(a_ptr + 5 * stride_a);

  const float32x4_t b0 = vld1q_f32(b_ptr);
  const float32x4_t b1 = vld1q_f32(b_ptr + stride_b);
  const float32x4_t b2 = vld1q_f32(b_ptr + 2 * stride_b);
  const float32x4_t b3 = vld1q_f32(b_ptr + 3 * stride_b);

  float32x4_t c0 = vld1q_f32(c_ptr);
  float32x4_t c1 = vld1q_f32(c_ptr + stride_c);
  float32x4_t c2 = vld1q_f32(c_ptr + 2 * stride_c);
  float32x4_t c3 = vld1q_f32(c_ptr + 3 * stride_c);
  float32x4_t c4 = vld1q_f32(c_ptr + 4 * stride_c);
  float32x4_t c5 = vld1q_f32(c_ptr + 5 * stride_c);

  // Each expansion updates c<N> from a<N> and b0..b3 (looked up by name).
  MACE_GEMM_PART_CAL_4(0);
  MACE_GEMM_PART_CAL_4(1);
  MACE_GEMM_PART_CAL_4(2);
  MACE_GEMM_PART_CAL_4(3);
  MACE_GEMM_PART_CAL_4(4);
  MACE_GEMM_PART_CAL_4(5);

  vst1q_f32(c_ptr, c0);
  vst1q_f32(c_ptr + stride_c, c1);
  vst1q_f32(c_ptr + 2 * stride_c, c2);
  vst1q_f32(c_ptr + 3 * stride_c, c3);
  vst1q_f32(c_ptr + 4 * stride_c, c4);
  vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
  GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
// Dispatches an (row x 4) * (4 x 4) accumulation to the fixed-height kernel
// Gemm<row>44. Supported heights are 1..6; any other value reaches
// MACE_NOT_IMPLEMENTED.
inline void GemmX44(const float *a_ptr,
                    const float *b_ptr,
                    const index_t stride_a,
                    const index_t stride_b,
                    const index_t stride_c,
                    float *c_ptr,
                    int row) {
  if (row == 1) {
    Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
  } else if (row == 2) {
    Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
  } else if (row == 3) {
    Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
  } else if (row == 4) {
    Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
  } else if (row == 5) {
    Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
  } else if (row == 6) {
    Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
  } else {
    // No kernel exists for this tile height.
    MACE_NOT_IMPLEMENTED;
  }
}
inline
void
Gemm884
(
const
float
*
a_ptr
,
const
float
*
b_ptr
,
const
index_t
stride_a
,
...
...
@@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr,
c6
=
vld1q_f32
(
c_ptr
+
6
*
stride_c
);
c7
=
vld1q_f32
(
c_ptr
+
7
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL
(
5
,
10
,
11
);
MACE_GEMM_PART_CAL
(
6
,
12
,
13
);
MACE_GEMM_PART_CAL
(
7
,
14
,
15
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL
(
5
,
10
,
11
);
MACE_GEMM_PART_CAL
(
6
,
12
,
13
);
MACE_GEMM_PART_CAL
(
7
,
14
,
15
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL_8
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL_8
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL_8
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL_8
(
5
,
10
,
11
);
MACE_GEMM_PART_CAL_8
(
6
,
12
,
13
);
MACE_GEMM_PART_CAL_8
(
7
,
14
,
15
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr,
c0
=
vld1q_f32
(
c_ptr
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
vst1q_f32
(
c_ptr
,
c0
);
#else
...
...
@@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr,
c0
=
vld1q_f32
(
c_ptr
);
c1
=
vld1q_f32
(
c_ptr
+
1
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr,
c1
=
vld1q_f32
(
c_ptr
+
1
*
stride_c
);
c2
=
vld1q_f32
(
c_ptr
+
2
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL_8
(
2
,
4
,
5
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr,
c2
=
vld1q_f32
(
c_ptr
+
2
*
stride_c
);
c3
=
vld1q_f32
(
c_ptr
+
3
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL_8
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL_8
(
3
,
6
,
7
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr,
c3
=
vld1q_f32
(
c_ptr
+
3
*
stride_c
);
c4
=
vld1q_f32
(
c_ptr
+
4
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL_8
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL_8
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL_8
(
4
,
8
,
9
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr,
c4
=
vld1q_f32
(
c_ptr
+
4
*
stride_c
);
c5
=
vld1q_f32
(
c_ptr
+
5
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL
(
5
,
10
,
11
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL
(
5
,
10
,
11
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL_8
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL_8
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL_8
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL_8
(
5
,
10
,
11
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr,
c5
=
vld1q_f32
(
c_ptr
+
5
*
stride_c
);
c6
=
vld1q_f32
(
c_ptr
+
6
*
stride_c
);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL
(
5
,
10
,
11
);
MACE_GEMM_PART_CAL
(
6
,
12
,
13
);
#else
MACE_GEMM_PART_CAL
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL
(
5
,
10
,
11
);
MACE_GEMM_PART_CAL
(
6
,
12
,
13
);
#endif
MACE_GEMM_PART_CAL_8
(
0
,
0
,
1
);
MACE_GEMM_PART_CAL_8
(
1
,
2
,
3
);
MACE_GEMM_PART_CAL_8
(
2
,
4
,
5
);
MACE_GEMM_PART_CAL_8
(
3
,
6
,
7
);
MACE_GEMM_PART_CAL_8
(
4
,
8
,
9
);
MACE_GEMM_PART_CAL_8
(
5
,
10
,
11
);
MACE_GEMM_PART_CAL_8
(
6
,
12
,
13
);
vst1q_f32
(
c_ptr
,
c0
);
vst1q_f32
(
c_ptr
+
1
*
stride_c
,
c1
);
...
...
@@ -589,9 +806,19 @@ inline void GemmTile(const float *A,
const
index_t
stride_c
,
float
*
C
)
{
#if defined(MACE_ENABLE_NEON)
index_t
h
,
w
,
k
;
for
(
h
=
0
;
h
<
height
-
7
;
h
+=
8
)
{
for
(
k
=
0
;
k
<
K
-
7
;
k
+=
8
)
{
index_t
h
=
0
;
index_t
w
=
0
;
index_t
k
=
0
;
#if defined(__aarch64__)
int
reg_height_tile
=
8
;
int
reg_K_tile
=
8
;
#else
int
reg_height_tile
=
6
;
int
reg_K_tile
=
4
;
#endif
for
(
h
=
0
;
h
<
height
-
reg_height_tile
+
1
;
h
+=
reg_height_tile
)
{
for
(
k
=
0
;
k
<
K
-
reg_K_tile
+
1
;
k
+=
reg_K_tile
)
{
const
float
*
a_ptr
=
A
+
(
h
*
stride_a
+
k
);
#if defined(__aarch64__) && defined(__clang__)
int
nw
=
width
>>
2
;
...
...
@@ -833,43 +1060,180 @@ inline void GemmTile(const float *A,
w
=
(
width
>>
2
)
<<
2
;
}
#el
se // gcc || armv7a
#el
if defined(__aarch64__) // gcc
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_b
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_c
+
w
);
Gemm884
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
}
#endif // clang && armv8a
#else // armv7
int
nw
=
width
>>
2
;
if
(
nw
>
0
)
{
float32x4_t
a0
,
a1
,
a2
,
a3
,
a4
,
a5
;
a0
=
vld1q_f32
(
a_ptr
);
a1
=
vld1q_f32
(
a_ptr
+
1
*
stride_a
);
a2
=
vld1q_f32
(
a_ptr
+
2
*
stride_a
);
a3
=
vld1q_f32
(
a_ptr
+
3
*
stride_a
);
a4
=
vld1q_f32
(
a_ptr
+
4
*
stride_a
);
a5
=
vld1q_f32
(
a_ptr
+
5
*
stride_a
);
const
float
*
b_ptr0
=
B
+
k
*
stride_b
;
const
float
*
b_ptr1
=
B
+
(
k
+
1
)
*
stride_b
;
const
float
*
b_ptr2
=
B
+
(
k
+
2
)
*
stride_b
;
const
float
*
b_ptr3
=
B
+
(
k
+
3
)
*
stride_b
;
float
*
c_ptr0
=
C
+
h
*
stride_c
;
float
*
c_ptr1
=
C
+
(
h
+
1
)
*
stride_c
;
float
*
c_ptr2
=
C
+
(
h
+
2
)
*
stride_c
;
float
*
c_ptr3
=
C
+
(
h
+
3
)
*
stride_c
;
float
*
c_ptr4
=
C
+
(
h
+
4
)
*
stride_c
;
float
*
c_ptr5
=
C
+
(
h
+
5
)
*
stride_c
;
asm
volatile
(
"pld [%7, #128]
\n
"
"vld1.f32 {d12-d13}, [%7]!
\n
"
"pld [%1, #128]
\n
"
"vld1.f32 {d16-d17}, [%1]
\n
"
"pld [%2, #128]
\n
"
"vld1.f32 {d18-d19}, [%2]
\n
"
"0:
\n
"
"pld [%3, #128]
\n
"
"vld1.f32 {d20-d21}, [%3]
\n
"
"pld [%4, #128]
\n
"
"vld1.f32 {d22-d23}, [%4]
\n
"
"pld [%5, #128]
\n
"
"vld1.f32 {d24-d25}, [%5]
\n
"
"pld [%6, #128]
\n
"
"vld1.f32 {d26-d27}, [%6]
\n
"
"pld [%8, #128]
\n
"
"vld1.f32 {d14-d15}, [%8]!
\n
"
"vmla.f32 q8, q6, %e22[0]
\n
"
"vmla.f32 q9, q6, %e23[0]
\n
"
"vmla.f32 q10, q6, %e24[0]
\n
"
"vmla.f32 q11, q6, %e25[0]
\n
"
"vmla.f32 q12, q6, %e26[0]
\n
"
"vmla.f32 q13, q6, %e27[0]
\n
"
"pld [%9, #128]
\n
"
"vld1.f32 {d12-d13}, [%9]!
\n
"
"vmla.f32 q8, q7, %e22[1]
\n
"
"vmla.f32 q9, q7, %e23[1]
\n
"
"vmla.f32 q10, q7, %e24[1]
\n
"
"vmla.f32 q11, q7, %e25[1]
\n
"
"vmla.f32 q12, q7, %e26[1]
\n
"
"vmla.f32 q13, q7, %e27[1]
\n
"
"pld [%10, #128]
\n
"
"vld1.f32 {d14-d15}, [%10]!
\n
"
"vmla.f32 q8, q6, %f22[0]
\n
"
"vmla.f32 q9, q6, %f23[0]
\n
"
"vmla.f32 q10, q6, %f24[0]
\n
"
"vmla.f32 q11, q6, %f25[0]
\n
"
"vmla.f32 q12, q6, %f26[0]
\n
"
"vmla.f32 q13, q6, %f27[0]
\n
"
"vmla.f32 q8, q7, %f22[1]
\n
"
"vmla.f32 q9, q7, %f23[1]
\n
"
"vmla.f32 q10, q7, %f24[1]
\n
"
"vmla.f32 q11, q7, %f25[1]
\n
"
"vmla.f32 q12, q7, %f26[1]
\n
"
"vmla.f32 q13, q7, %f27[1]
\n
"
"vst1.f32 {d16-d17}, [%1]!
\n
"
"vst1.f32 {d18-d19}, [%2]!
\n
"
"pld [%7, #128]
\n
"
"vld1.f32 {d12-d13}, [%7]!
\n
"
"vst1.f32 {d20-d21}, [%3]!
\n
"
"vst1.f32 {d22-d23}, [%4]!
\n
"
"pld [%1, #128]
\n
"
"vld1.f32 {d16-d17}, [%1]
\n
"
"vst1.f32 {d24-d25}, [%5]!
\n
"
"vst1.f32 {d26-d27}, [%6]!
\n
"
"pld [%2, #128]
\n
"
"vld1.f32 {d18-d19}, [%2]
\n
"
"subs %0, #1
\n
"
"bne 0b
\n
"
:
"=r"
(
nw
),
// 0
"=r"
(
c_ptr0
),
// 1
"=r"
(
c_ptr1
),
// 2
"=r"
(
c_ptr2
),
// 3
"=r"
(
c_ptr3
),
// 4
"=r"
(
c_ptr4
),
// 5
"=r"
(
c_ptr5
),
// 6
"=r"
(
b_ptr0
),
// 7
"=r"
(
b_ptr1
),
// 8
"=r"
(
b_ptr2
),
// 9
"=r"
(
b_ptr3
)
// 10
:
"0"
(
nw
),
// 11
"1"
(
c_ptr0
),
// 12
"2"
(
c_ptr1
),
// 13
"3"
(
c_ptr2
),
// 14
"4"
(
c_ptr3
),
// 15
"5"
(
c_ptr4
),
// 16
"6"
(
c_ptr5
),
// 17
"7"
(
b_ptr0
),
// 18
"8"
(
b_ptr1
),
// 19
"9"
(
b_ptr2
),
// 20
"10"
(
b_ptr3
),
// 21
"w"
(
a0
),
// 22
"w"
(
a1
),
// 23
"w"
(
a2
),
// 24
"w"
(
a3
),
// 25
"w"
(
a4
),
// 26
"w"
(
a5
)
// 27
:
"cc"
,
"memory"
,
"q6"
,
"q7"
,
"q8"
,
"q9"
,
"q10"
,
"q11"
,
"q12"
,
"q13"
,
"q14"
,
"q15"
);
w
=
(
width
>>
2
)
<<
2
;
}
#endif
if
(
w
<
width
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_b
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_c
+
w
);
GemmBlock
(
a_ptr
,
b_ptr
,
8
,
8
,
width
-
w
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
GemmBlock
(
a_ptr
,
b_ptr
,
reg_height_tile
,
reg_K_tile
,
width
-
w
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
}
}
if
(
k
<
K
)
{
const
float
*
a_ptr
=
A
+
(
h
*
stride_a
+
k
);
const
float
*
b_ptr
=
B
+
k
*
stride_b
;
float
*
c_ptr
=
C
+
h
*
stride_c
;
GemmBlock
(
a_ptr
,
b_ptr
,
8
,
K
-
k
,
width
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
GemmBlock
(
a_ptr
,
b_ptr
,
reg_height_tile
,
K
-
k
,
width
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
}
}
if
(
h
<
height
)
{
index_t
remain_h
=
height
-
h
;
for
(
k
=
0
;
k
<
K
-
7
;
k
+=
8
)
{
for
(
k
=
0
;
k
<
K
-
reg_K_tile
;
k
+=
reg_K_tile
)
{
const
float
*
a_ptr
=
A
+
(
h
*
stride_a
+
k
);
index_t
w
;
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_b
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_c
+
w
);
#if defined(__aarch64__)
GemmX84
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
,
remain_h
);
#else
GemmX44
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
,
remain_h
);
#endif
}
if
(
w
<
width
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_b
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_c
+
w
);
GemmBlock
(
a_ptr
,
b_ptr
,
remain_h
,
8
,
width
-
w
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
GemmBlock
(
a_ptr
,
b_ptr
,
remain_h
,
reg_K_tile
,
width
-
w
,
stride_a
,
stride_
b
,
stride_
c
,
c_ptr
);
}
}
if
(
k
<
K
)
{
...
...
mace/ops/shape_test.cc
浏览文件 @
a376a1b8
...
...
@@ -38,7 +38,9 @@ void TestShapeOp(const std::vector<index_t> &input_shape) {
std
::
vector
<
int32_t
>
expected_input_shape
(
input_shape
.
begin
(),
input_shape
.
end
());
if
(
!
expected_input_shape
.
empty
())
{
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"ExpectedOutput"
,
{
input_shape
.
size
()},
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"ExpectedOutput"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
expected_input_shape
);
}
else
{
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"ExpectedOutput"
,
{},
{
0
});
...
...
mace/ops/strided_slice_test.cc
浏览文件 @
a376a1b8
...
...
@@ -37,11 +37,18 @@ void TestSlice(const std::vector<index_t> &input_shape,
const
std
::
vector
<
float
>
&
output
)
{
OpsTestNet
net
;
net
.
AddInputFromArray
<
CPU
,
float
>
(
"Input"
,
input_shape
,
input
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
input_shape
.
size
()},
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"BeginIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
begin_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"EndIndices"
,
{
input_shape
.
size
()},
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"EndIndices"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
end_indices
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Strides"
,
{
input_shape
.
size
()},
strides
);
net
.
AddInputFromArray
<
CPU
,
int32_t
>
(
"Strides"
,
{
static_cast
<
int32_t
>
(
input_shape
.
size
())},
strides
);
OpDefBuilder
(
"StridedSlice"
,
"StridedSliceOpTest"
)
.
Input
(
"Input"
)
...
...
mace/python/tools/converter_tool/base_converter.py
浏览文件 @
a376a1b8
...
...
@@ -164,6 +164,7 @@ class TransformerRule(Enum):
TRANSFORM_BUFFER_IMAGE
=
17
ADD_DEVICE_AND_DATA_TYPE
=
18
SORT_BY_EXECUTION
=
19
ADD_IN_OUT_TENSOR_INFO
=
20
class
ConverterInterface
(
object
):
...
...
@@ -210,6 +211,7 @@ class ConverterOption(object):
self
.
_device
=
DeviceType
.
CPU
.
value
self
.
_winograd_enabled
=
False
self
.
_transformer_option
=
[
TransformerRule
.
ADD_IN_OUT_TENSOR_INFO
,
TransformerRule
.
REMOVE_USELESS_RESHAPE_OP
,
TransformerRule
.
REMOVE_IDENTITY_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
...
...
mace/python/tools/converter_tool/tensorflow_converter.py
浏览文件 @
a376a1b8
...
...
@@ -166,6 +166,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
self
.
_option
=
option
self
.
_mace_net_def
=
mace_pb2
.
NetDef
()
ConverterUtil
.
set_filter_format
(
self
.
_mace_net_def
,
FilterFormat
.
HWIO
)
# import tensorflow graph
tf_graph_def
=
tf
.
GraphDef
()
with
tf
.
gfile
.
Open
(
src_model_file
,
'rb'
)
as
f
:
tf_graph_def
.
ParseFromString
(
f
.
read
())
...
...
mace/python/tools/converter_tool/transformer.py
浏览文件 @
a376a1b8
...
...
@@ -55,6 +55,7 @@ class Transformer(base_converter.ConverterInterface):
def
__init__
(
self
,
option
,
model
):
# DO NOT reorder the following transformers' order
self
.
_registered_transformers_order
=
[
TransformerRule
.
ADD_IN_OUT_TENSOR_INFO
,
TransformerRule
.
REMOVE_USELESS_RESHAPE_OP
,
TransformerRule
.
REMOVE_IDENTITY_OP
,
TransformerRule
.
TRANSFORM_GLOBAL_POOLING
,
...
...
@@ -78,6 +79,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule
.
SORT_BY_EXECUTION
,
]
self
.
_registered_transformers
=
{
TransformerRule
.
ADD_IN_OUT_TENSOR_INFO
:
self
.
add_in_out_tensor_info
,
TransformerRule
.
REMOVE_USELESS_RESHAPE_OP
:
self
.
remove_useless_reshape_op
,
TransformerRule
.
REMOVE_IDENTITY_OP
:
self
.
remove_identity_op
,
...
...
@@ -271,6 +274,21 @@ class Transformer(base_converter.ConverterInterface):
self
.
_model
.
op
.
remove
(
op
)
def add_in_out_tensor_info(self):
    """Record name and dims of each configured input/output on the model.

    One ``input_info`` entry is appended per node in
    ``self._option.input_nodes`` (dims taken from the node's ``shape``), and
    one ``output_info`` entry per node in ``self._option.output_nodes`` (dims
    taken from the first output shape of the op registered for that name in
    ``self._producer``).

    Always returns False.
    """
    model = self._model

    for in_node in self._option.input_nodes.values():
        entry = model.input_info.add()
        entry.name = in_node.name
        entry.dims.extend(in_node.shape)

    for out_node in self._option.output_nodes.values():
        entry = model.output_info.add()
        entry.name = out_node.name
        # Outputs carry no shape of their own here; use the producing op's.
        producing_op = self._producer[out_node.name]
        entry.dims.extend(producing_op.output_shape[0].dims)

    return False
def
remove_useless_reshape_op
(
self
):
net
=
self
.
_model
for
op
in
net
.
op
:
...
...
mace/test/BUILD
浏览文件 @
a376a1b8
...
...
@@ -50,3 +50,24 @@ cc_test(
"@gtest//:gtest_main"
,
],
)
# Test target for mace_api_exception_test.cc.
cc_test(
    name = "mace_api_exception_test",
    testonly = 1,
    srcs = ["mace_api_exception_test.cc"],
    # Warnings are errors; the remaining flags are added per build
    # configuration by the if_* helpers.
    copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"] +
            if_openmp_enabled(["-fopenmp"]) +
            if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
            if_android_armv7(["-mfpu=neon"]) +
            if_android_armv7(["-mfloat-abi=softfp"]) +
            if_android(["-DMACE_ENABLE_OPENCL"]) +
            if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
    linkopts = ["-fopenmp"],
    linkstatic = 1,
    deps = [
        "//mace/ops:test",
        "//mace/kernels:kernels",
        "//mace/ops:ops",
        "@gtest//:gtest_main",
    ],
)
mace/test/mace_api_exception_test.cc
0 → 100644
浏览文件 @
a376a1b8
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace
mace
{
namespace
test
{
// Checks that MaceEngine::Init() terminates the process when it is given an
// input node name that was never registered in the NetDef's input_info.
TEST(MaceAPIExceptionTest, WrongInputTest) {
  const std::vector<std::string> input_names{MakeString("input", 0)};
  const std::vector<std::string> output_names{MakeString("output", 0)};

  const DeviceType device = DeviceType::GPU;

  // Register only the generated input name with the model definition.
  std::shared_ptr<NetDef> net_def(new NetDef());
  for (const std::string &input_name : input_names) {
    InputInfo *info = net_def->add_input_info();
    info->set_name(input_name);
  }

  MaceEngine engine(device);
  // "input" differs from the registered name, so Init() is expected to die.
  // The empty pattern accepts any message written by the dying process.
  ASSERT_DEATH(engine.Init(net_def.get(), {"input"}, output_names, nullptr),
               "");
}
}
// namespace test
}
// namespace mace
mace/test/mace_api_mt_test.cc
浏览文件 @
a376a1b8
...
...
@@ -298,6 +298,8 @@ void MaceRunFunc(const int in_out_size) {
{
mem_map
[
input_names
[
i
]]},
device
,
net_def
.
get
());
InputInfo
*
info
=
net_def
->
add_input_info
();
info
->
set_name
(
input_names
[
i
]);
}
BufferToImage
<
half
>
(
filter_tensor_name
,
filter_tensor_img_name
,
mace
::
kernels
::
CONV2D_FILTER
,
{},
device
,
...
...
@@ -315,6 +317,8 @@ void MaceRunFunc(const int in_out_size) {
mace
::
kernels
::
IN_OUT_CHANNEL
,
device
,
net_def
.
get
());
OutputInfo
*
info
=
net_def
->
add_output_info
();
info
->
set_name
(
output_names
[
i
]);
}
const
std
::
string
file_path
=
"/data/local/tmp/mace"
;
...
...
mace/test/mace_api_test.cc
浏览文件 @
a376a1b8
...
...
@@ -308,6 +308,8 @@ void MaceRun(const int in_out_size,
{
mem_map
[
input_names
[
i
]]},
device
,
net_def
.
get
());
InputInfo
*
info
=
net_def
->
add_input_info
();
info
->
set_name
(
input_names
[
i
]);
}
BufferToImage
<
half
>
(
filter_tensor_name
,
filter_tensor_img_name
,
mace
::
kernels
::
CONV2D_FILTER
,
{},
device
,
...
...
@@ -324,6 +326,8 @@ void MaceRun(const int in_out_size,
mace
::
kernels
::
IN_OUT_CHANNEL
,
device
,
net_def
.
get
());
OutputInfo
*
info
=
net_def
->
add_output_info
();
info
->
set_name
(
output_names
[
i
]);
}
MaceEngine
engine
(
device
);
...
...
@@ -376,5 +380,6 @@ TEST_F(MaceAPITest, GPUVariableInputShape) {
{{
1
,
16
,
32
,
16
},
{
1
,
32
,
64
,
16
}},
{
16
,
16
,
3
,
3
});
}
}
// namespace test
}
// namespace mace
mace/utils/utils.h
浏览文件 @
a376a1b8
...
...
@@ -16,6 +16,7 @@
#define MACE_UTILS_UTILS_H_
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
...
...
@@ -152,5 +153,14 @@ inline bool ReadBinaryFile(std::vector<unsigned char> *data,
return
true
;
}
// Returns a copy of every key in `data`, in the map's iteration order (for
// std::map that is ascending key order). An empty map yields an empty vector.
template <typename T>
std::vector<std::string> MapKeys(const std::map<std::string, T> &data) {
  std::vector<std::string> keys;
  // The final size is known up front; reserve once instead of growing.
  keys.reserve(data.size());
  for (const auto &kv : data) {
    keys.push_back(kv.first);
  }
  return keys;
}
}
// namespace mace
#endif // MACE_UTILS_UTILS_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录