Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
328839bd
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
328839bd
编写于
9月 15, 2017
作者:
L
liuqi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Finish conv2d 3x3 stride 2 neon kernel.
上级
d20d5ad8
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
259 addition
and
188 deletion
+259
-188
mace/kernels/neon/conv_2d_neon_3x3.cc
mace/kernels/neon/conv_2d_neon_3x3.cc
+253
-186
mace/ops/conv_2d_benchmark.cc
mace/ops/conv_2d_benchmark.cc
+4
-0
mace/ops/conv_2d_test.cc
mace/ops/conv_2d_test.cc
+2
-2
未找到文件。
mace/kernels/neon/conv_2d_neon_3x3.cc
浏览文件 @
328839bd
...
...
@@ -8,221 +8,288 @@
namespace
mace
{
namespace
kernels
{
#define KERNEL_HEAD_CODE \
int output_batch = output_shape[0]; \
int output_channels = output_shape[1]; \
int output_height = output_shape[2]; \
int output_width = output_shape[3]; \
int input_batch = input_shape[0]; \
int input_channels = input_shape[1]; \
int input_height = input_shape[2]; \
int input_width = input_shape[3]; \
int kernel_h = 3; \
int kernel_w = 3; \
for (int b = 0; b < output_batch; ++b) { \
float* output_ptr_base = output + b * output_channels * output_height * output_width; \
for (int oc = 0; oc < output_channels; ++oc) { \
const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; \
const float* input_ptr = input + b * input_channels * input_height * input_width; \
float* output_ptr = output_ptr_base + oc * output_height * output_width; \
std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \
for (int ic = 0; ic < input_channels; ++ic) { \
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
#define KERNEL_TAIL_CODE \
filter_ptr += 9; \
input_ptr += input_height * input_width; \
} \
} \
}
static
const
int
kRegisterSize
=
4
;
void
Conv2dNeonK3x3S1
(
const
float
*
input
,
// NCHW
const
index_t
*
input_shape
,
const
float
*
filter
,
// c_out, c_in, kernel_h, kernel_w
const
float
*
bias
,
// c_out
float
*
output
,
// NCHW
const
index_t
*
output_shape
)
{
int
batch
=
output_shape
[
0
];
int
channels
=
output_shape
[
1
];
int
height
=
output_shape
[
2
];
int
width
=
output_shape
[
3
];
int
input_batch
=
input_shape
[
0
];
int
input_channels
=
input_shape
[
1
];
int
input_height
=
input_shape
[
2
];
int
input_width
=
input_shape
[
3
];
int
kernel_h
=
3
;
int
kernel_w
=
3
;
int
height_count
=
(
height
>>
1
)
<<
1
;
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
float
*
output_ptr_base
=
output
+
b
*
channels
*
height
*
width
;
for
(
int
oc
=
0
;
oc
<
channels
;
++
oc
)
{
const
float
*
filter_ptr
=
filter
+
oc
*
input_channels
*
kernel_h
*
kernel_w
;
const
float
*
input_ptr
=
input
+
b
*
input_channels
*
input_height
*
input_width
;
float
*
output_ptr
=
output_ptr_base
+
oc
*
height
*
width
;
std
::
fill
(
output_ptr
,
output_ptr
+
height
*
width
,
bias
[
oc
]);
for
(
int
ic
=
0
;
ic
<
input_channels
;
++
ic
)
{
float32x4_t
filter0
=
vld1q_f32
(
filter_ptr
);
float32x4_t
filter3
=
vld1q_f32
(
filter_ptr
+
3
);
float32x4_t
filter6
=
vld1q_f32
(
filter_ptr
+
6
);
const
float
*
row
[
kRegisterSize
]
=
{
input_ptr
,
input_ptr
+
input_width
,
input_ptr
+
2
*
input_width
,
input_ptr
+
3
*
input_width
};
float
*
output_ptr1
=
output_ptr
;
float
*
output_ptr2
=
output_ptr
+
width
;
void
Conv2dNeonK3x3S1
(
const
float
*
input
,
// NCHW
const
index_t
*
input_shape
,
const
float
*
filter
,
// c_out, c_in, kernel_h, kernel_w
const
float
*
bias
,
// c_out
float
*
output
,
// NCHW
const
index_t
*
output_shape
)
{
int
height_count
=
(
output_shape
[
2
]
>>
1
)
<<
1
;
KERNEL_HEAD_CODE
const
float
*
row_ptr_v
[
kRegisterSize
]
=
{
input_ptr
,
input_ptr
+
input_width
,
input_ptr
+
2
*
input_width
,
input_ptr
+
3
*
input_width
};
float
*
output_ptr_v
[]
=
{
output_ptr
,
output_ptr
+
output_width
};
for
(
int
h
=
0
;
h
<
height_count
;
h
+=
2
)
{
int
count
=
width
>>
2
;
int
remain_count
=
width
&
3
;
int
count
=
output_
width
>>
2
;
int
remain_count
=
output_
width
&
3
;
for
(;
count
>
0
;
--
count
)
{
float32x4_t
sum0
=
vdupq_n_f32
(
.0
f
);
float32x4_t
sum1
=
vdupq_n_f32
(
.0
f
);
float32x4_t
row0_ext_0
=
vld1q_f32
(
row
[
0
]);
// 0123
float32x4_t
row0_latter
=
vld1q_f32
(
row
[
0
]
+
kRegisterSize
);
// 4567
float32x4_t
row0_ext_1
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
1
);
// 1234
float32x4_t
row0_ext_2
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_0
,
filter0
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_1
,
filter0
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_2
,
filter0
,
2
);
float32x4_t
row1_ext_0
=
vld1q_f32
(
row
[
1
]);
// 0123
float32x4_t
row1_latter
=
vld1q_f32
(
row
[
1
]
+
kRegisterSize
);
// 4567
float32x4_t
row1_ext_1
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
1
);
// 1234
float32x4_t
row1_ext_2
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_0
,
filter3
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_1
,
filter3
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_2
,
filter3
,
2
);
row0_ext_0
=
vld1q_f32
(
row
[
2
]);
// 0123
row0_latter
=
vld1q_f32
(
row
[
2
]
+
kRegisterSize
);
// 4567
row0_ext_1
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
1
);
// 1234
row0_ext_2
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_0
,
filter6
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_1
,
filter6
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_2
,
filter6
,
2
);
float32x4_t
n_sum0
=
vdupq_n_f32
(
.0
f
);
float32x4_t
n_row_former
=
vld1q_f32
(
row_ptr_v
[
0
]);
float32x4_t
n_row_latter
=
vld1q_f32
(
row_ptr_v
[
0
]
+
kRegisterSize
);
float32x4_t
n_row_ext0
=
vextq_f32
(
n_row_former
,
n_row_latter
,
1
);
float32x4_t
n_row_ext1
=
vextq_f32
(
n_row_former
,
n_row_latter
,
2
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row_former
,
n_filter_v
[
0
],
0
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row_ext0
,
n_filter_v
[
0
],
1
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row_ext1
,
n_filter_v
[
0
],
2
);
float32x4_t
n_row1_former
=
vld1q_f32
(
row_ptr_v
[
1
]);
float32x4_t
n_row1_latter
=
vld1q_f32
(
row_ptr_v
[
1
]
+
kRegisterSize
);
float32x4_t
n_row1_ext0
=
vextq_f32
(
n_row1_former
,
n_row1_latter
,
1
);
float32x4_t
n_row1_ext1
=
vextq_f32
(
n_row1_former
,
n_row1_latter
,
2
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row1_former
,
n_filter_v
[
1
],
0
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row1_ext0
,
n_filter_v
[
1
],
1
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row1_ext1
,
n_filter_v
[
1
],
2
);
n_row_former
=
vld1q_f32
(
row_ptr_v
[
2
]);
n_row_latter
=
vld1q_f32
(
row_ptr_v
[
2
]
+
kRegisterSize
);
n_row_ext0
=
vextq_f32
(
n_row_former
,
n_row_latter
,
1
);
n_row_ext1
=
vextq_f32
(
n_row_former
,
n_row_latter
,
2
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row_former
,
n_filter_v
[
2
],
0
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row_ext0
,
n_filter_v
[
2
],
1
);
n_sum0
=
vfmaq_laneq_f32
(
n_sum0
,
n_row_ext1
,
n_filter_v
[
2
],
2
);
// second row
sum1
=
vfmaq_laneq_f32
(
sum1
,
row1_ext_0
,
filter0
,
0
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row1_ext_1
,
filter0
,
1
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row1_ext_2
,
filter0
,
2
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row0_ext_0
,
filter3
,
0
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row0_ext_1
,
filter3
,
1
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row0_ext_2
,
filter3
,
2
);
row1_ext_0
=
vld1q_f32
(
row
[
3
]);
// 0123
row1_latter
=
vld1q_f32
(
row
[
3
]
+
kRegisterSize
);
// 4567
row1_ext_1
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
1
);
// 1234
row1_ext_2
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
2
);
// 2345
sum1
=
vfmaq_laneq_f32
(
sum1
,
row1_ext_0
,
filter6
,
0
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row1_ext_1
,
filter6
,
1
);
sum1
=
vfmaq_laneq_f32
(
sum1
,
row1_ext_2
,
filter6
,
2
);
float32x4_t
output_row0
=
vld1q_f32
(
output_ptr1
);
float32x4_t
output_row1
=
vld1q_f32
(
output_ptr2
);
output_row0
=
vaddq_f32
(
output_row0
,
sum0
);
output_row1
=
vaddq_f32
(
output_row1
,
sum1
);
vst1q_f32
(
output_ptr1
,
output_row0
);
vst1q_f32
(
output_ptr
2
,
output_row1
);
output_ptr
1
+=
kRegisterSize
;
output_ptr
2
+=
kRegisterSize
;
float32x4_t
n_sum1
=
vdupq_n_f32
(
.0
f
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row1_former
,
n_filter_v
[
0
],
0
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row1_ext0
,
n_filter_v
[
0
],
1
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row1_ext1
,
n_filter_v
[
0
],
2
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row_former
,
n_filter_v
[
1
],
0
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row_ext0
,
n_filter_v
[
1
],
1
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row_ext1
,
n_filter_v
[
1
],
2
);
n_row1_former
=
vld1q_f32
(
row_ptr_v
[
3
]);
n_row1_latter
=
vld1q_f32
(
row_ptr_v
[
3
]
+
kRegisterSize
);
n_row1_ext0
=
vextq_f32
(
n_row1_former
,
n_row1_latter
,
1
);
n_row1_ext1
=
vextq_f32
(
n_row1_former
,
n_row1_latter
,
2
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row1_former
,
n_filter_v
[
2
],
0
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row1_ext0
,
n_filter_v
[
2
],
1
);
n_sum1
=
vfmaq_laneq_f32
(
n_sum1
,
n_row1_ext1
,
n_filter_v
[
2
],
2
);
float32x4_t
n_output_row
=
vld1q_f32
(
output_ptr_v
[
0
]
);
float32x4_t
n_output_row1
=
vld1q_f32
(
output_ptr_v
[
1
]
);
n_output_row
=
vaddq_f32
(
n_output_row
,
n_sum0
);
n_output_row1
=
vaddq_f32
(
n_output_row1
,
n_sum1
);
vst1q_f32
(
output_ptr
_v
[
0
],
n_output_row
);
vst1q_f32
(
output_ptr_v
[
1
],
n_output_row1
);
output_ptr
_v
[
0
]
+=
kRegisterSize
;
output_ptr
_v
[
1
]
+=
kRegisterSize
;
for
(
int
i
=
0
;
i
<
kRegisterSize
;
++
i
)
{
row
[
i
]
+=
kRegisterSize
;
row
_ptr_v
[
i
]
+=
kRegisterSize
;
}
}
for
(;
remain_count
>
0
;
--
remain_count
)
{
float32x4_t
row0
=
vld1q_f32
(
row
[
0
]);
// 0123
float32x4_t
row1
=
vld1q_f32
(
row
[
1
]);
// 0123
float32x4_t
row2
=
vld1q_f32
(
row
[
2
]);
// 0123
float32x4_t
row3
=
vld1q_f32
(
row
[
3
]);
// 0123
float32x4_t
sum
=
vmulq_f32
(
row0
,
filter0
);
sum
=
vmlaq_f32
(
sum
,
row1
,
filter3
);
sum
=
vmlaq_f32
(
sum
,
row2
,
filter6
);
sum
=
vsetq_lane_f32
(
*
output_ptr1
,
sum
,
3
);
*
output_ptr1
=
vaddvq_f32
(
sum
);
sum
=
vmulq_f32
(
row1
,
filter0
);
sum
=
vmlaq_f32
(
sum
,
row2
,
filter3
);
sum
=
vmlaq_f32
(
sum
,
row3
,
filter6
);
sum
=
vsetq_lane_f32
(
*
output_ptr2
,
sum
,
3
);
*
output_ptr2
=
vaddvq_f32
(
sum
);
++
output_ptr1
;
++
output_ptr2
;
float32x4_t
n_row_v
[]
=
{
vld1q_f32
(
row_ptr_v
[
0
]),
vld1q_f32
(
row_ptr_v
[
1
]),
vld1q_f32
(
row_ptr_v
[
2
])
};
float32x4_t
n_sum0
=
vmulq_f32
(
n_row_v
[
0
],
n_filter_v
[
0
]);
n_sum0
=
vmlaq_f32
(
n_sum0
,
n_row_v
[
1
],
n_filter_v
[
1
]);
n_sum0
=
vmlaq_f32
(
n_sum0
,
n_row_v
[
2
],
n_filter_v
[
2
]);
n_sum0
=
vsetq_lane_f32
(
*
output_ptr_v
[
0
],
n_sum0
,
3
);
*
output_ptr_v
[
0
]
=
vaddvq_f32
(
n_sum0
);
float32x4_t
n_row3
=
vld1q_f32
(
row_ptr_v
[
3
]);
float32x4_t
n_sum1
=
vmulq_f32
(
n_row_v
[
1
],
n_filter_v
[
0
]);
n_sum1
=
vmlaq_f32
(
n_sum1
,
n_row_v
[
2
],
n_filter_v
[
1
]);
n_sum1
=
vmlaq_f32
(
n_sum1
,
n_row3
,
n_filter_v
[
2
]);
n_sum1
=
vsetq_lane_f32
(
*
output_ptr_v
[
1
],
n_sum1
,
3
);
*
output_ptr_v
[
1
]
=
vaddvq_f32
(
n_sum1
);
++
output_ptr_v
[
0
];
++
output_ptr_v
[
1
];
for
(
int
i
=
0
;
i
<
kRegisterSize
;
++
i
)
{
row
[
i
]
+=
1
;
row
_ptr_v
[
i
]
+=
1
;
}
}
output_ptr
1
+=
width
;
output_ptr
2
+=
width
;
output_ptr
_v
[
0
]
+=
output_
width
;
output_ptr
_v
[
1
]
+=
output_
width
;
for
(
int
i
=
0
;
i
<
kRegisterSize
;
++
i
)
{
row
[
i
]
+=
2
+
input_width
;
row
_ptr_v
[
i
]
+=
2
+
input_width
;
}
}
if
(
height
!=
height_count
)
{
int
count
=
width
>>
2
;
int
remain_count
=
width
&
3
;
if
(
output_
height
!=
height_count
)
{
int
count
=
output_
width
>>
2
;
int
remain_count
=
output_
width
&
3
;
for
(;
count
>
0
;
--
count
)
{
float32x4_t
sum0
=
vdupq_n_f32
(
.0
f
);
float32x4_t
row0_ext_0
=
vld1q_f32
(
row
[
0
]);
// 0123
float32x4_t
row0_latter
=
vld1q_f32
(
row
[
0
]
+
kRegisterSize
);
// 4567
float32x4_t
row0_ext_1
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
1
);
// 1234
float32x4_t
row0_ext_2
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_0
,
filter0
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_1
,
filter0
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_2
,
filter0
,
2
);
float32x4_t
row1_ext_0
=
vld1q_f32
(
row
[
1
]);
// 0123
float32x4_t
row1_latter
=
vld1q_f32
(
row
[
1
]
+
kRegisterSize
);
// 4567
float32x4_t
row1_ext_1
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
1
);
// 1234
float32x4_t
row1_ext_2
=
vextq_f32
(
row1_ext_0
,
row1_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_0
,
filter3
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_1
,
filter3
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row1_ext_2
,
filter3
,
2
);
row0_ext_0
=
vld1q_f32
(
row
[
2
]);
// 0123
row0_latter
=
vld1q_f32
(
row
[
2
]
+
kRegisterSize
);
// 4567
row0_ext_1
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
1
);
// 1234
row0_ext_2
=
vextq_f32
(
row0_ext_0
,
row0_latter
,
2
);
// 2345
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_0
,
filter6
,
0
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_1
,
filter6
,
1
);
sum0
=
vfmaq_laneq_f32
(
sum0
,
row0_ext_2
,
filter6
,
2
);
float32x4_t
output_row0
=
vld1q_f32
(
output_ptr1
);
output_row0
=
vaddq_f32
(
output_row0
,
sum0
);
vst1q_f32
(
output_ptr1
,
output_row0
);
output_ptr1
+=
kRegisterSize
;
float32x4_t
n_sum
=
vdupq_n_f32
(
.0
f
);
float32x4_t
n_row_former
=
vld1q_f32
(
row_ptr_v
[
0
]);
float32x4_t
n_row_latter
=
vld1q_f32
(
row_ptr_v
[
0
]
+
kRegisterSize
);
float32x4_t
n_row_ext1
=
vextq_f32
(
n_row_former
,
n_row_latter
,
1
);
float32x4_t
n_row_ext2
=
vextq_f32
(
n_row_former
,
n_row_latter
,
2
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_former
,
n_filter_v
[
0
],
0
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext1
,
n_filter_v
[
0
],
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext2
,
n_filter_v
[
0
],
2
);
n_row_former
=
vld1q_f32
(
row_ptr_v
[
1
]);
n_row_latter
=
vld1q_f32
(
row_ptr_v
[
1
]
+
kRegisterSize
);
n_row_ext1
=
vextq_f32
(
n_row_former
,
n_row_latter
,
1
);
n_row_ext2
=
vextq_f32
(
n_row_former
,
n_row_latter
,
2
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_former
,
n_filter_v
[
1
],
0
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext1
,
n_filter_v
[
1
],
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext2
,
n_filter_v
[
1
],
2
);
n_row_former
=
vld1q_f32
(
row_ptr_v
[
2
]);
n_row_latter
=
vld1q_f32
(
row_ptr_v
[
2
]
+
kRegisterSize
);
n_row_ext1
=
vextq_f32
(
n_row_former
,
n_row_latter
,
1
);
n_row_ext2
=
vextq_f32
(
n_row_former
,
n_row_latter
,
2
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_former
,
n_filter_v
[
2
],
0
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext1
,
n_filter_v
[
2
],
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext2
,
n_filter_v
[
2
],
2
);
float32x4_t
n_output_row
=
vld1q_f32
(
output_ptr_v
[
0
]);
n_output_row
=
vaddq_f32
(
n_output_row
,
n_sum
);
vst1q_f32
(
output_ptr_v
[
0
],
n_output_row
);
output_ptr_v
[
0
]
+=
kRegisterSize
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row
[
i
]
+=
kRegisterSize
;
row
_ptr_v
[
i
]
+=
kRegisterSize
;
}
}
for
(;
remain_count
>
0
;
--
remain_count
)
{
float32x4_t
row0
=
vld1q_f32
(
row
[
0
]);
// 0123
float32x4_t
row1
=
vld1q_f32
(
row
[
1
]);
// 0123
float32x4_t
row2
=
vld1q_f32
(
row
[
2
]);
// 0123
float32x4_t
n_row_v
[]
=
{
vld1q_f32
(
row_ptr_v
[
0
]),
vld1q_f32
(
row_ptr_v
[
1
]),
vld1q_f32
(
row_ptr_v
[
2
]),
};
float32x4_t
n_sum
=
vmulq_f32
(
n_row_v
[
0
],
n_filter_v
[
0
]);
n_sum
=
vmlaq_f32
(
n_sum
,
n_row_v
[
1
],
n_filter_v
[
1
]);
n_sum
=
vmlaq_f32
(
n_sum
,
n_row_v
[
2
],
n_filter_v
[
2
]);
n_sum
=
vsetq_lane_f32
(
*
output_ptr_v
[
0
],
n_sum
,
3
);
*
output_ptr_v
[
0
]
=
vaddvq_f32
(
n_sum
);
++
output_ptr_v
[
0
];
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row_ptr_v
[
i
]
+=
1
;
}
}
}
KERNEL_TAIL_CODE
}
void
Conv2dNeonK3x3S2
(
const
float
*
input
,
// NCHW
const
index_t
*
input_shape
,
const
float
*
filter
,
// c_out, c_in, kernel_h, kernel_w
const
float
*
bias
,
// c_out
float
*
output
,
// NCHW
const
index_t
*
output_shape
)
{
int
tail_step
=
2
*
(
input_shape
[
3
]
-
output_shape
[
3
]);
float32x4_t
sum
=
vmulq_f32
(
row0
,
filter0
);
sum
=
vmlaq_f32
(
sum
,
row1
,
filter3
);
sum
=
vmlaq_f32
(
sum
,
row2
,
filter6
);
sum
=
vsetq_lane_f32
(
*
output_ptr1
,
sum
,
3
);
*
output_ptr1
=
vaddvq_f32
(
sum
);
KERNEL_HEAD_CODE
++
output_ptr1
;
const
float
*
row_ptr_v
[
3
]
=
{
input_ptr
,
input_ptr
+
input_width
,
input_ptr
+
2
*
input_width
};
float
*
output_ptr_inner
=
output_ptr
;
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
count
=
output_width
>>
2
;
int
remain_count
=
output_width
&
3
;
for
(;
count
>
0
;
--
count
)
{
float32x4_t
n_sum
=
vdupq_n_f32
(
.0
f
);
float32x4x2_t
n_row_former
=
vld2q_f32
(
row_ptr_v
[
0
]);
float32x4_t
n_row_latter
=
vld1q_f32
(
row_ptr_v
[
0
]
+
8
);
float32x4_t
n_row_ext
=
vextq_f32
(
n_row_former
.
val
[
0
],
n_row_latter
,
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_former
.
val
[
0
],
n_filter_v
[
0
],
0
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_former
.
val
[
1
],
n_filter_v
[
0
],
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row_ext
,
n_filter_v
[
0
],
2
);
float32x4x2_t
n_row1_former
=
vld2q_f32
(
row_ptr_v
[
1
]);
float32x4_t
n_row1_latter
=
vld1q_f32
(
row_ptr_v
[
1
]
+
8
);
float32x4_t
n_row1_ext
=
vextq_f32
(
n_row1_former
.
val
[
0
],
n_row1_latter
,
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row1_former
.
val
[
0
],
n_filter_v
[
1
],
0
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row1_former
.
val
[
1
],
n_filter_v
[
1
],
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row1_ext
,
n_filter_v
[
1
],
2
);
float32x4x2_t
n_row2_former
=
vld2q_f32
(
row_ptr_v
[
2
]);
float32x4_t
n_row2_latter
=
vld1q_f32
(
row_ptr_v
[
2
]
+
8
);
float32x4_t
n_row2_ext
=
vextq_f32
(
n_row2_former
.
val
[
0
],
n_row2_latter
,
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row2_former
.
val
[
0
],
n_filter_v
[
2
],
0
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row2_former
.
val
[
1
],
n_filter_v
[
2
],
1
);
n_sum
=
vfmaq_laneq_f32
(
n_sum
,
n_row2_ext
,
n_filter_v
[
2
],
2
);
float32x4_t
n_output_row
=
vld1q_f32
(
output_ptr_inner
);
n_output_row
=
vaddq_f32
(
n_output_row
,
n_sum
);
vst1q_f32
(
output_ptr_inner
,
n_output_row
);
output_ptr_inner
+=
kRegisterSize
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row_ptr_v
[
i
]
+=
2
*
kRegisterSize
;
}
}
for
(;
remain_count
>
0
;
--
remain_count
)
{
float32x4_t
n_row_v
[]
=
{
vld1q_f32
(
row_ptr_v
[
0
]),
vld1q_f32
(
row_ptr_v
[
1
]),
vld1q_f32
(
row_ptr_v
[
2
])
};
float32x4_t
n_sum
=
vmulq_f32
(
n_row_v
[
0
],
n_filter_v
[
0
]);
n_sum
=
vmlaq_f32
(
n_sum
,
n_row_v
[
1
],
n_filter_v
[
1
]);
n_sum
=
vmlaq_f32
(
n_sum
,
n_row_v
[
2
],
n_filter_v
[
2
]);
n_sum
=
vsetq_lane_f32
(
*
output_ptr_inner
,
n_sum
,
3
);
*
output_ptr_inner
=
vaddvq_f32
(
n_sum
);
++
output_ptr_inner
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row
[
i
]
+=
1
;
row
_ptr_v
[
i
]
+=
2
;
}
}
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
row_ptr_v
[
i
]
+=
tail_step
;
}
}
filter_ptr
+=
9
;
input_ptr
+=
input_height
*
input_width
;
}
}
}
KERNEL_TAIL_CODE
}
}
// namespace kernels
}
// namespace mace
#undef KERNEL_HEAD_CODE
#undef KERNEL_TAIL_CODE
}
// namespace kernels
}
// namespace mace
mace/ops/conv_2d_benchmark.cc
浏览文件 @
328839bd
...
...
@@ -76,6 +76,10 @@ BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D
(
1
,
64
,
33
,
31
,
3
,
3
,
1
,
VALID
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
32
,
32
,
3
,
3
,
1
,
SAME
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
33
,
31
,
3
,
3
,
1
,
SAME
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
32
,
32
,
3
,
3
,
2
,
VALID
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
33
,
31
,
3
,
3
,
2
,
VALID
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
32
,
32
,
3
,
3
,
2
,
SAME
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
33
,
31
,
3
,
3
,
2
,
SAME
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
32
,
32
,
5
,
5
,
1
,
VALID
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
32
,
31
,
5
,
5
,
1
,
VALID
,
128
,
float
);
BM_CONV_2D
(
1
,
64
,
32
,
32
,
5
,
5
,
1
,
SAME
,
128
,
float
);
...
...
mace/ops/conv_2d_test.cc
浏览文件 @
328839bd
...
...
@@ -174,8 +174,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
// generate random input
index_t
batch
=
1
+
rand
()
%
10
;
index_t
input_channels
=
1
+
rand
()
%
50
;
index_t
height
=
7
+
rand
()
%
100
;
index_t
width
=
7
+
rand
()
%
100
;
index_t
height
=
11
+
rand
()
%
100
;
index_t
width
=
11
+
rand
()
%
100
;
index_t
output_channels
=
1
+
rand
()
%
50
;
// Construct graph
auto
&
net
=
test_net
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录