Commit 8bb5716a
Authored Sep 18, 2017 by Liangliang He
Merge branch 'conv2d-neon' into 'master'
Neon conv2d 3x3 stride 2 kernel. See merge request !44
Parents: 291a5ee6, 89b4b039
Showing 5 changed files with 287 additions and 209 deletions (+287, -209)
mace/kernels/neon/conv_2d_neon.cc        +23  -15
mace/kernels/neon/conv_2d_neon_3x3.cc    +253 -186
mace/ops/conv_2d_benchmark.cc            +5   -2
mace/ops/conv_2d_test.cc                 +2   -2
mace/ops/pooling_test.cc                 +4   -4
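The headline change is the NEON 3x3 stride-2 kernel in conv_2d_neon_3x3.cc. As a rough, hedged illustration of the stride-2 trick that kernel relies on (vld2q_f32 de-interleaves even and odd input columns so four stride-2 outputs can be accumulated at once), here is a minimal, self-contained sketch; the function and variable names are hypothetical and this is not the MACE code itself:

// Hedged sketch (not MACE code): one 3-tap, stride-2 row accumulation for
// four adjacent outputs, mirroring the vld2q_f32/vextq_f32 pattern used by
// the new kernel. Requires an AArch64 NEON toolchain.
#include <arm_neon.h>
#include <cstdio>

// acc[j] += sum_k in_row[2*j + k] * filter_row[k], for j = 0..3, k = 0..2.
static inline float32x4_t conv3_stride2_row(float32x4_t acc,
                                            const float *in_row,
                                            float32x4_t filter_row) {
  // Even columns 0,2,4,6 land in val[0]; odd columns 1,3,5,7 in val[1].
  float32x4x2_t in = vld2q_f32(in_row);
  // Columns 8..11, needed to build the "column + 2" vector {2,4,6,8}.
  float32x4_t tail = vld1q_f32(in_row + 8);
  float32x4_t even_plus2 = vextq_f32(in.val[0], tail, 1);
  acc = vfmaq_laneq_f32(acc, in.val[0], filter_row, 0);   // taps at col 0,2,4,6
  acc = vfmaq_laneq_f32(acc, in.val[1], filter_row, 1);   // taps at col 1,3,5,7
  acc = vfmaq_laneq_f32(acc, even_plus2, filter_row, 2);  // taps at col 2,4,6,8
  return acc;
}

int main() {
  float in[12], out[4] = {0.f, 0.f, 0.f, 0.f};
  for (int i = 0; i < 12; ++i) in[i] = static_cast<float>(i);
  const float f[4] = {1.f, 2.f, 3.f, 0.f};      // lane 3 is padding, unused
  float32x4_t filter_row = vld1q_f32(f);
  float32x4_t acc = vdupq_n_f32(0.f);
  acc = conv3_stride2_row(acc, in, filter_row);
  vst1q_f32(out, acc);
  // Expect out[j] = in[2j]*1 + in[2j+1]*2 + in[2j+2]*3.
  for (int j = 0; j < 4; ++j) std::printf("%f\n", out[j]);
  return 0;
}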
mace/kernels/neon/conv_2d_neon.cc
@@ -22,6 +22,13 @@ extern void Conv2dNeonK3x3S1(const float *input,
                             float *output,
                             const index_t *output_shape);

+extern void Conv2dNeonK3x3S2(const float *input,
+                             const index_t *input_shape,
+                             const float *filter,
+                             const float *bias,
+                             float *output,
+                             const index_t *output_shape);
+
extern void Conv2dNeonK5x5S1(const float *input,
                             const index_t *input_shape,
                             const float *filter,
@@ -30,27 +37,25 @@ extern void Conv2dNeonK5x5S1(const float *input,
                             const index_t *output_shape);

template <>
-void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,  // NCHW
+void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
                                                        const index_t *input_shape,
-                                                        const float *filter,  // c_out, c_in, kernel_h, kernel_w
+                                                        const float *filter,
                                                        const index_t *filter_shape,
-                                                        const float *bias,    // c_out
-                                                        float *output,        // NCHW
+                                                        const float *bias,
+                                                        float *output,
                                                        const index_t *output_shape) {
-  typedef void (*Conv2dNeonFunction)(const float *input,  // NCHW
+  typedef void (*Conv2dNeonFunction)(const float *input,
                                     const index_t *input_shape,
-                                     const float *filter,  // c_out, c_in, kernel_h, kernel_w
-                                     const float *bias,    // c_out
-                                     float *output,        // NCHW
+                                     const float *filter,
+                                     const float *bias,
+                                     float *output,
                                     const index_t *output_shape);
  // Selection matrix: kernel_size x stride_size
  static const Conv2dNeonFunction selector[5][2] = {
      {Conv2dNeonK1x1S1, nullptr},
      {nullptr, nullptr},
-      {Conv2dNeonK3x3S1, nullptr},
+      {Conv2dNeonK3x3S1, Conv2dNeonK3x3S2},
      {nullptr, nullptr},
      {Conv2dNeonK5x5S1, nullptr}};
  // not implement yet
@@ -59,7 +64,10 @@ operator()(const float *input, // NCHW
  if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
      strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
      selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
-    LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion";
+    LOG(WARNING) << "NEON conv2d kernel with "
+                 << "filter" << kernel_h << "x" << kernel_w << ","
+                 << " stride " << strides_[0] << "x" << strides_[1]
+                 << " is not implemented yet, using slow version";
    Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
        input, input_shape, filter, filter_shape, bias, output, output_shape);
    return;
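The fallback path above indexes a kernel_size x stride function-pointer table and drops to the generic CPU functor when the entry is nullptr (or when strides or dilations are unsupported). A stripped-down sketch of that dispatch pattern, with hypothetical kernel names rather than the real MACE entry points:

// Hedged sketch of the selection-matrix dispatch; Kernel3x3S1/S2 and Run are
// illustrative stand-ins, not MACE functions.
#include <cstdio>

using KernelFn = void (*)(const float *input, float *output);

static void Kernel3x3S1(const float *, float *) { std::puts("3x3 stride 1"); }
static void Kernel3x3S2(const float *, float *) { std::puts("3x3 stride 2"); }

// Rows: kernel size 1..5; columns: stride 1..2. nullptr means "no fast path".
static const KernelFn selector[5][2] = {
    {nullptr, nullptr},
    {nullptr, nullptr},
    {Kernel3x3S1, Kernel3x3S2},
    {nullptr, nullptr},
    {nullptr, nullptr}};

static void Run(int kernel, int stride, const float *in, float *out) {
  KernelFn fn = nullptr;
  if (kernel >= 1 && kernel <= 5 && stride >= 1 && stride <= 2)
    fn = selector[kernel - 1][stride - 1];
  if (fn == nullptr) {
    std::puts("falling back to the slow generic path");
    return;
  }
  fn(in, out);
}

int main() {
  Run(3, 2, nullptr, nullptr);  // dispatches to the stride-2 kernel
  Run(5, 2, nullptr, nullptr);  // no entry -> generic fallback
  return 0;
}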
mace/kernels/neon/conv_2d_neon_3x3.cc
@@ -8,221 +8,288 @@
namespace mace {
namespace kernels {
+#define KERNEL_HEAD_CODE \
+  int output_batch = output_shape[0]; \
+  int output_channels = output_shape[1]; \
+  int output_height = output_shape[2]; \
+  int output_width = output_shape[3]; \
+  int input_batch = input_shape[0]; \
+  int input_channels = input_shape[1]; \
+  int input_height = input_shape[2]; \
+  int input_width = input_shape[3]; \
+  int kernel_h = 3; \
+  int kernel_w = 3; \
+  for (int b = 0; b < output_batch; ++b) { \
+    float* output_ptr_base = output + b * output_channels * output_height * output_width; \
+    for (int oc = 0; oc < output_channels; ++oc) { \
+      const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; \
+      const float* input_ptr = input + b * input_channels * input_height * input_width; \
+      float* output_ptr = output_ptr_base + oc * output_height * output_width; \
+      std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \
+      for (int ic = 0; ic < input_channels; ++ic) { \
+        float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
+
+#define KERNEL_TAIL_CODE \
+  filter_ptr += 9; \
+  input_ptr += input_height * input_width; \
+  } \
+  } \
+  }
static const int kRegisterSize = 4;

-void Conv2dNeonK3x3S1(const float *input,  // NCHW
-                      const index_t *input_shape,
-                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
-                      const float *bias,    // c_out
-                      float *output,        // NCHW
-                      const index_t *output_shape) {
-  int batch = output_shape[0];
-  int channels = output_shape[1];
-  int height = output_shape[2];
-  int width = output_shape[3];
-  int input_batch = input_shape[0];
-  int input_channels = input_shape[1];
-  int input_height = input_shape[2];
-  int input_width = input_shape[3];
-  int kernel_h = 3;
-  int kernel_w = 3;
-  int height_count = (height >> 1) << 1;
-  for (int b = 0; b < batch; ++b) {
-    float *output_ptr_base = output + b * channels * height * width;
-    for (int oc = 0; oc < channels; ++oc) {
-      const float *filter_ptr = filter + oc * input_channels * kernel_h * kernel_w;
-      const float *input_ptr = input + b * input_channels * input_height * input_width;
-      float *output_ptr = output_ptr_base + oc * height * width;
-      std::fill(output_ptr, output_ptr + height * width, bias[oc]);
-      for (int ic = 0; ic < input_channels; ++ic) {
-        float32x4_t filter0 = vld1q_f32(filter_ptr);
-        float32x4_t filter3 = vld1q_f32(filter_ptr + 3);
-        float32x4_t filter6 = vld1q_f32(filter_ptr + 6);
-        const float *row[kRegisterSize] = {input_ptr, input_ptr + input_width,
-                                           input_ptr + 2 * input_width,
-                                           input_ptr + 3 * input_width};
-        float *output_ptr1 = output_ptr;
-        float *output_ptr2 = output_ptr + width;
+void Conv2dNeonK3x3S1(const float *input,  // NCHW
+                      const index_t *input_shape,
+                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
+                      const float *bias,    // c_out
+                      float *output,        // NCHW
+                      const index_t *output_shape) {
+  int height_count = (output_shape[2] >> 1) << 1;
+  KERNEL_HEAD_CODE
+  const float *row_ptr_v[kRegisterSize] = {input_ptr, input_ptr + input_width,
+                                           input_ptr + 2 * input_width,
+                                           input_ptr + 3 * input_width};
+  float *output_ptr_v[] = {output_ptr, output_ptr + output_width};
  for (int h = 0; h < height_count; h += 2) {
-    int count = width >> 2;
-    int remain_count = width & 3;
+    int count = output_width >> 2;
+    int remain_count = output_width & 3;
    for (; count > 0; --count) {
-      float32x4_t sum0 = vdupq_n_f32(.0f);
-      float32x4_t sum1 = vdupq_n_f32(.0f);
-      float32x4_t row0_ext_0 = vld1q_f32(row[0]);                   // 0123
-      float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize);  // 4567
-      float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-      float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
-      float32x4_t row1_ext_0 = vld1q_f32(row[1]);                   // 0123
-      float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize);  // 4567
-      float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
-      float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2);  // 2345
-      sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
-      sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
-      sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
-      row0_ext_0 = vld1q_f32(row[2]);                   // 0123
-      row0_latter = vld1q_f32(row[2] + kRegisterSize);  // 4567
-      row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-      row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2);
+      float32x4_t n_sum0 = vdupq_n_f32(.0f);
+      float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
+      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
+      float32x4_t n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
+      float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[0], 0);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[0], 1);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[0], 2);
+      float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]);
+      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
+      float32x4_t n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
+      float32x4_t n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2);
+      n_row_former = vld1q_f32(row_ptr_v[2]);
+      n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
+      n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
+      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[2], 0);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[2], 1);
+      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[2], 2);
// second row
-      sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter0, 0);
-      sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter0, 1);
-      sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter0, 2);
-      sum1 = vfmaq_laneq_f32(sum1, row0_ext_0, filter3, 0);
-      sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1);
-      sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2);
-      row1_ext_0 = vld1q_f32(row[3]);                   // 0123
-      row1_latter = vld1q_f32(row[3] + kRegisterSize);  // 4567
-      row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
-      row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2);  // 2345
-      sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter6, 0);
-      sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1);
-      sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter6, 2);
-      float32x4_t output_row0 = vld1q_f32(output_ptr1);
-      float32x4_t output_row1 = vld1q_f32(output_ptr2);
-      output_row0 = vaddq_f32(output_row0, sum0);
-      output_row1 = vaddq_f32(output_row1, sum1);
-      vst1q_f32(output_ptr1, output_row0);
-      vst1q_f32(output_ptr2, output_row1);
-      output_ptr1 += kRegisterSize;
-      output_ptr2 += kRegisterSize;
+      float32x4_t n_sum1 = vdupq_n_f32(.0f);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[0], 0);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[0], 1);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[0], 2);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_former, n_filter_v[1], 0);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext0, n_filter_v[1], 1);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext1, n_filter_v[1], 2);
+      n_row1_former = vld1q_f32(row_ptr_v[3]);
+      n_row1_latter = vld1q_f32(row_ptr_v[3] + kRegisterSize);
+      n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
+      n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[2], 0);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[2], 1);
+      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[2], 2);
+      float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
+      float32x4_t n_output_row1 = vld1q_f32(output_ptr_v[1]);
+      n_output_row = vaddq_f32(n_output_row, n_sum0);
+      n_output_row1 = vaddq_f32(n_output_row1, n_sum1);
+      vst1q_f32(output_ptr_v[0], n_output_row);
+      vst1q_f32(output_ptr_v[1], n_output_row1);
+      output_ptr_v[0] += kRegisterSize;
+      output_ptr_v[1] += kRegisterSize;
      for (int i = 0; i < kRegisterSize; ++i) {
-        row[i] += kRegisterSize;
+        row_ptr_v[i] += kRegisterSize;
      }
    }
    for (; remain_count > 0; --remain_count) {
-      float32x4_t row0 = vld1q_f32(row[0]);  // 0123
-      float32x4_t row1 = vld1q_f32(row[1]);  // 0123
-      float32x4_t row2 = vld1q_f32(row[2]);  // 0123
-      float32x4_t row3 = vld1q_f32(row[3]);  // 0123
-      float32x4_t sum = vmulq_f32(row0, filter0);
-      sum = vmlaq_f32(sum, row1, filter3);
-      sum = vmlaq_f32(sum, row2, filter6);
-      sum = vsetq_lane_f32(*output_ptr1, sum, 3);
-      *output_ptr1 = vaddvq_f32(sum);
-      sum = vmulq_f32(row1, filter0);
-      sum = vmlaq_f32(sum, row2, filter3);
-      sum = vmlaq_f32(sum, row3, filter6);
-      sum = vsetq_lane_f32(*output_ptr2, sum, 3);
-      *output_ptr2 = vaddvq_f32(sum);
-      ++output_ptr1;
-      ++output_ptr2;
+      float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[1]),
+                               vld1q_f32(row_ptr_v[2])};
+      float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]);
+      n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]);
+      n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]);
+      n_sum0 = vsetq_lane_f32(*output_ptr_v[0], n_sum0, 3);
+      *output_ptr_v[0] = vaddvq_f32(n_sum0);
+      float32x4_t n_row3 = vld1q_f32(row_ptr_v[3]);
+      float32x4_t n_sum1 = vmulq_f32(n_row_v[1], n_filter_v[0]);
+      n_sum1 = vmlaq_f32(n_sum1, n_row_v[2], n_filter_v[1]);
+      n_sum1 = vmlaq_f32(n_sum1, n_row3, n_filter_v[2]);
+      n_sum1 = vsetq_lane_f32(*output_ptr_v[1], n_sum1, 3);
+      *output_ptr_v[1] = vaddvq_f32(n_sum1);
+      ++output_ptr_v[0];
+      ++output_ptr_v[1];
      for (int i = 0; i < kRegisterSize; ++i) {
-        row[i] += 1;
+        row_ptr_v[i] += 1;
      }
    }
-    output_ptr1 += width;
-    output_ptr2 += width;
+    output_ptr_v[0] += output_width;
+    output_ptr_v[1] += output_width;
    for (int i = 0; i < kRegisterSize; ++i) {
-      row[i] += 2 + input_width;
+      row_ptr_v[i] += 2 + input_width;
    }
  }
-  if (height != height_count) {
-    int count = width >> 2;
-    int remain_count = width & 3;
+  if (output_height != height_count) {
+    int count = output_width >> 2;
+    int remain_count = output_width & 3;
    for (; count > 0; --count) {
-      float32x4_t sum0 = vdupq_n_f32(.0f);
-      float32x4_t row0_ext_0 = vld1q_f32(row[0]);                   // 0123
-      float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize);  // 4567
-      float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-      float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
-      float32x4_t row1_ext_0 = vld1q_f32(row[1]);                   // 0123
-      float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize);  // 4567
-      float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1);  // 1234
-      float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2);  // 2345
-      sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0);
-      sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1);
-      sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
-      row0_ext_0 = vld1q_f32(row[2]);                   // 0123
-      row0_latter = vld1q_f32(row[2] + kRegisterSize);  // 4567
-      row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1);  // 1234
-      row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2);  // 2345
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1);
-      sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2);
-      float32x4_t output_row0 = vld1q_f32(output_ptr1);
-      output_row0 = vaddq_f32(output_row0, sum0);
-      vst1q_f32(output_ptr1, output_row0);
-      output_ptr1 += kRegisterSize;
+      float32x4_t n_sum = vdupq_n_f32(.0f);
+      float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
+      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
+      float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
+      float32x4_t n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[0], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[0], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[0], 2);
+      n_row_former = vld1q_f32(row_ptr_v[1]);
+      n_row_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
+      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
+      n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[1], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[1], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[1], 2);
+      n_row_former = vld1q_f32(row_ptr_v[2]);
+      n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
+      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
+      n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[2], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[2], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[2], 2);
+      float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
+      n_output_row = vaddq_f32(n_output_row, n_sum);
+      vst1q_f32(output_ptr_v[0], n_output_row);
+      output_ptr_v[0] += kRegisterSize;
      for (int i = 0; i < 3; ++i) {
-        row[i] += kRegisterSize;
+        row_ptr_v[i] += kRegisterSize;
      }
    }
    for (; remain_count > 0; --remain_count) {
-      float32x4_t row0 = vld1q_f32(row[0]);  // 0123
-      float32x4_t row1 = vld1q_f32(row[1]);  // 0123
-      float32x4_t row2 = vld1q_f32(row[2]);  // 0123
+      float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[1]),
+                               vld1q_f32(row_ptr_v[2]),};
+      float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
+      n_sum = vsetq_lane_f32(*output_ptr_v[0], n_sum, 3);
+      *output_ptr_v[0] = vaddvq_f32(n_sum);
+      ++output_ptr_v[0];
      for (int i = 0; i < 3; ++i) {
+        row_ptr_v[i] += 1;
      }
    }
  }
+  KERNEL_TAIL_CODE
}
+void Conv2dNeonK3x3S2(const float *input,  // NCHW
+                      const index_t *input_shape,
+                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
+                      const float *bias,    // c_out
+                      float *output,        // NCHW
+                      const index_t *output_shape) {
+  int tail_step = 2 * (input_shape[3] - output_shape[3]);
+  KERNEL_HEAD_CODE
-      float32x4_t sum = vmulq_f32(row0, filter0);
-      sum = vmlaq_f32(sum, row1, filter3);
-      sum = vmlaq_f32(sum, row2, filter6);
-      sum = vsetq_lane_f32(*output_ptr1, sum, 3);
-      *output_ptr1 = vaddvq_f32(sum);
-      ++output_ptr1;
+  const float *row_ptr_v[3] = {input_ptr, input_ptr + input_width,
+                               input_ptr + 2 * input_width};
+  float *output_ptr_inner = output_ptr;
+  for (int h = 0; h < output_height; ++h) {
+    int count = output_width >> 2;
+    int remain_count = output_width & 3;
+    for (; count > 0; --count) {
+      float32x4_t n_sum = vdupq_n_f32(.0f);
+      float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]);
+      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8);
+      float32x4_t n_row_ext = vextq_f32(n_row_former.val[0], n_row_latter, 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2);
+      float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]);
+      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8);
+      float32x4_t n_row1_ext = vextq_f32(n_row1_former.val[0], n_row1_latter, 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2);
+      float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]);
+      float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8);
+      float32x4_t n_row2_ext = vextq_f32(n_row2_former.val[0], n_row2_latter, 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1);
+      n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2);
+      float32x4_t n_output_row = vld1q_f32(output_ptr_inner);
+      n_output_row = vaddq_f32(n_output_row, n_sum);
+      vst1q_f32(output_ptr_inner, n_output_row);
+      output_ptr_inner += kRegisterSize;
      for (int i = 0; i < 3; ++i) {
-        row[i] += 1;
+        row_ptr_v[i] += 2 * kRegisterSize;
      }
    }
    for (; remain_count > 0; --remain_count) {
+      float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[1]),
+                               vld1q_f32(row_ptr_v[2])};
+      float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
+      n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
+      n_sum = vsetq_lane_f32(*output_ptr_inner, n_sum, 3);
+      *output_ptr_inner = vaddvq_f32(n_sum);
+      ++output_ptr_inner;
      for (int i = 0; i < 3; ++i) {
+        row_ptr_v[i] += 2;
      }
-      filter_ptr += 9;
-      input_ptr += input_height * input_width;
    }
    for (int i = 0; i < 3; ++i) {
+      row_ptr_v[i] += tail_step;
    }
  }
+  KERNEL_TAIL_CODE
}
+#undef KERNEL_HEAD_CODE
+#undef KERNEL_TAIL_CODE
}  // namespace kernels
}  // namespace mace
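Both kernels now share their outer loop nest (batch, output channel, input channel, plus the per-channel filter load and bias fill) through the KERNEL_HEAD_CODE / KERNEL_TAIL_CODE pair defined above. A much simplified, hedged sketch of that head/tail macro factoring; LOOP_HEAD, LOOP_TAIL, and the two kernel functions are illustrative stand-ins, not the MACE macros:

// Hedged sketch of factoring a shared loop nest into head/tail macros so that
// several kernels only differ in their innermost body.
#include <cstdio>

#define LOOP_HEAD                     \
  for (int b = 0; b < batch; ++b) {   \
    for (int c = 0; c < channels; ++c) {

#define LOOP_TAIL \
    }             \
  }

static void KernelS1(int batch, int channels) {
  LOOP_HEAD
  std::printf("stride-1 body for b=%d c=%d\n", b, c);  // kernel-specific inner code
  LOOP_TAIL
}

static void KernelS2(int batch, int channels) {
  LOOP_HEAD
  std::printf("stride-2 body for b=%d c=%d\n", b, c);  // kernel-specific inner code
  LOOP_TAIL
}

int main() {
  KernelS1(1, 2);
  KernelS2(1, 2);
  return 0;
}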
mace/ops/conv_2d_benchmark.cc
@@ -61,8 +61,7 @@ static void Conv2d(int iters,
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
    mace::testing::ItemsProcessed(tot); \
    mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \
-    Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, \
-                         OC); \
+    Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \
  } \
  BENCHMARK( \
      BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE)
@@ -77,6 +76,10 @@ BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
+BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
+BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
+BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
+BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
mace/ops/conv_2d_test.cc
@@ -174,8 +174,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
    // generate random input
    index_t batch = 1 + rand() % 10;
    index_t input_channels = 1 + rand() % 50;
-    index_t height = 7 + rand() % 100;
-    index_t width = 7 + rand() % 100;
+    index_t height = 11 + rand() % 100;
+    index_t width = 11 + rand() % 100;
    index_t output_channels = 1 + rand() % 50;
    // Construct graph
    auto &net = test_net();
mace/ops/pooling_test.cc
@@ -155,9 +155,9 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
  net.RunOp(DeviceType::NEON);

  // Check
-  Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {6, 8, 9, 16, 18, 19});

-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}

TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
@@ -183,7 +183,7 @@ TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
  net.RunOp(DeviceType::NEON);

  // Check
-  Tensor expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {11, 13, 14, 16, 18, 19});

-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}