Commit d3ccb99b authored by 李寅

Merge branch 'conv7x7_v8' into 'master'

optimize conv7x7 s1 s2 s3 armv8 neon

See merge request !426
@@ -21,7 +21,73 @@
namespace mace {
namespace kernels {
#define Conv2dNeonK7x7SnLoadCalc4 \
#define Conv2dArmv8NeonK7x7SnLoadCalc4 \
/* load filter (4 outch x 1 height x 4 width) */ \
float32x4_t vf00, vf01; \
float32x4_t vf10, vf11; \
float32x4_t vf20, vf21; \
float32x4_t vf30, vf31; \
vf00 = vld1q_f32(filter_ptr0); \
vf01 = vld1q_f32(filter_ptr0 + 4); \
vf10 = vld1q_f32(filter_ptr1); \
vf11 = vld1q_f32(filter_ptr1 + 4); \
vf20 = vld1q_f32(filter_ptr2); \
vf21 = vld1q_f32(filter_ptr2 + 4); \
vf30 = vld1q_f32(filter_ptr3); \
vf31 = vld1q_f32(filter_ptr3 + 4); \
\
/* outch 0 */ \
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); \
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); \
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); \
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); \
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 0); \
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 1); \
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 2); \
\
/* outch 1 */ \
vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0); \
vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1); \
vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2); \
vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3); \
vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 0); \
vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 1); \
vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 2); \
\
/* outch 2 */ \
vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0); \
vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1); \
vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2); \
vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3); \
vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 0); \
vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 1); \
vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 2); \
\
/* outch 3 */ \
vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0); \
vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1); \
vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2); \
vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3); \
vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 0); \
vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 1); \
vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 2);
#define Conv2dArmv8NeonK7x7SnLoadCalc1 \
/* load filter (1 outch x 1 height x 4 width) */ \
float32x4_t vf00, vf01; \
vf00 = vld1q_f32(filter_ptr0); \
vf01 = vld1q_f32(filter_ptr0 + 4); \
\
/* outch 0 */ \
vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); \
vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); \
vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); \
vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); \
vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 0); \
vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 1); \
vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 2);
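Note for readers: `vfmaq_laneq_f32` is an AArch64-only intrinsic that fuses a multiply by a single filter lane into the accumulator. A minimal standalone sketch (not part of the patch) of what each line in the macros above computes:

```cpp
// Minimal sketch, assuming an AArch64 target: one step of the macro above.
// vfmaq_laneq_f32(acc, x, v, lane) computes acc[i] + x[i] * v[lane] for
// i = 0..3 as a single fused multiply-add; `lane` must be a compile-time
// constant in 0..3.
#include <arm_neon.h>

float32x4_t fma_by_lane(float32x4_t acc, float32x4_t x, float32x4_t v) {
  return vfmaq_laneq_f32(acc, x, v, 0);  // acc += x * v[0]
}
```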
#define Conv2dArmv7NeonK7x7SnLoadCalc4 \
/* load filter (4 outch x 1 height x 4 width) */ \
float32x4_t vf00, vf01; \
float32x4_t vf10, vf11; \
@@ -72,7 +138,7 @@ namespace kernels {
vo3 = vmlaq_lane_f32(vo3, vi5, vget_low_f32(vf31), 1); \
vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 0);
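On ARMv7 there is no `vfmaq_laneq_f32`; the by-lane multiply-accumulate can only index a 64-bit half-register, which is why the macro above pairs `vmlaq_lane_f32` with `vget_low_f32`/`vget_high_f32`. A sketch of how lane 2 of a 128-bit vector is reached (helper name is illustrative):

```cpp
// Sketch, assuming an ARMv7 NEON target: vmlaq_lane_f32 takes a float32x2_t,
// so lanes 0-1 come from vget_low_f32 and lanes 2-3 from vget_high_f32.
// Unlike the A64 vfmaq_*, this is a separate multiply and add, not fused.
#include <arm_neon.h>

float32x4_t mla_by_lane2(float32x4_t acc, float32x4_t x, float32x4_t v) {
  // acc[i] += x[i] * v[2]; lane 2 is index 0 of the high half.
  return vmlaq_lane_f32(acc, x, vget_high_f32(v), 0);
}
```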
#define Conv2dNeonK7x7SnLoadCalc1 \
#define Conv2dArmv7NeonK7x7SnLoadCalc1 \
/* load filter (1 outch x 1 height x 4 width) */ \
float32x4_t vf00, vf01; \
vf00 = vld1q_f32(filter_ptr0); \
@@ -148,7 +214,7 @@ void Conv2dNeonK7x7S1(const float *input,
filter + (m + 2) * in_channels * 49 + c * 49;
const float *filter_ptr3 =
filter + (m + 3) * in_channels * 49 + c * 49;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
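The `filter_ptr` arithmetic above relies on an OIHW weight layout: each (output channel, input channel) pair owns a contiguous 7 × 7 = 49-float block. A small helper (hypothetical, for illustration only) that names the same indexing:

```cpp
// Illustrative helper, not in the patch: start of the 7x7 kernel block for
// output channel m and input channel c, assuming OIHW layout.
#include <cstddef>

inline const float* filter_block(const float* filter, std::ptrdiff_t m,
                                 std::ptrdiff_t c, std::ptrdiff_t in_channels) {
  return filter + m * in_channels * 49 + c * 49;
}
```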
@@ -175,7 +241,11 @@ void Conv2dNeonK7x7S1(const float *input,
vi5 = vextq_f32(vi4, vi8, 1);
vi6 = vextq_f32(vi4, vi8, 2);
Conv2dNeonK7x7SnLoadCalc4;
#if defined(__aarch64__)
Conv2dArmv8NeonK7x7SnLoadCalc4;
#else
Conv2dArmv7NeonK7x7SnLoadCalc4;
#endif
in_offset += in_width;
filter_ptr0 += 7;
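For stride 1, the seven shifted input windows come from a few aligned loads plus `vextq_f32`, which concatenates two registers and extracts four consecutive lanes at a given offset, replacing six overlapping loads. A standalone sketch of the pattern:

```cpp
// Sketch, assuming in_ptr points at >= 8 readable floats: derive shifted
// stride-1 windows from two loads with vextq_f32 instead of reloading.
#include <arm_neon.h>

void stride1_windows(const float* in_ptr) {
  float32x4_t vi0 = vld1q_f32(in_ptr);       // [x0 x1 x2 x3]
  float32x4_t vi4 = vld1q_f32(in_ptr + 4);   // [x4 x5 x6 x7]
  float32x4_t vi1 = vextq_f32(vi0, vi4, 1);  // [x1 x2 x3 x4]
  float32x4_t vi2 = vextq_f32(vi0, vi4, 2);  // [x2 x3 x4 x5]
  float32x4_t vi3 = vextq_f32(vi0, vi4, 3);  // [x3 x4 x5 x6]
  (void)vi1; (void)vi2; (void)vi3;           // consumed by the LoadCalc macros
}
```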
@@ -211,7 +281,7 @@ void Conv2dNeonK7x7S1(const float *input,
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + mm * in_channels * 49 + c * 49;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
@@ -235,7 +305,11 @@ void Conv2dNeonK7x7S1(const float *input,
vi5 = vextq_f32(vi4, vi8, 1);
vi6 = vextq_f32(vi4, vi8, 2);
Conv2dNeonK7x7SnLoadCalc1;
#if defined(__aarch64__)
Conv2dArmv8NeonK7x7SnLoadCalc1;
#else
Conv2dArmv7NeonK7x7SnLoadCalc1;
#endif
in_offset += in_width;
filter_ptr0 += 7;
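This guard change is the heart of the patch: the NEON path no longer excludes AArch64; instead the same loop body selects the A64 or ARMv7 macro at compile time. The same dispatch idea, condensed into a hypothetical one-lane helper:

```cpp
// Sketch of the compile-time dispatch this patch introduces: one loop body,
// two FMA flavors, chosen by the target architecture.
#include <arm_neon.h>

inline float32x4_t mla_lane0(float32x4_t acc, float32x4_t x, float32x4_t v) {
#if defined(__aarch64__)
  return vfmaq_laneq_f32(acc, x, v, 0);               // fused A64 form
#else
  return vmlaq_lane_f32(acc, x, vget_low_f32(v), 0);  // ARMv7 mul + add
#endif
}
```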
@@ -294,7 +368,7 @@ void Conv2dNeonK7x7S2(const float *input,
filter + (m + 2) * in_channels * 49 + c * 49;
const float *filter_ptr3 =
filter + (m + 3) * in_channels * 49 + c * 49;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
@@ -326,7 +400,11 @@ void Conv2dNeonK7x7S2(const float *input,
vi5 = vextq_f32(vi1, vvi1.val[1], 2); // [5.7.9.11]
vi6 = vextq_f32(vi0, vvi1.val[0], 3); // [6.8.10.12]
Conv2dNeonK7x7SnLoadCalc4;
#if defined(__aarch64__)
Conv2dArmv8NeonK7x7SnLoadCalc4;
#else
Conv2dArmv7NeonK7x7SnLoadCalc4;
#endif
in_offset += in_width;
filter_ptr0 += 7;
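For stride 2, `vld2q_f32` loads eight consecutive floats and de-interleaves them into even- and odd-indexed registers, which are exactly the first two stride-2 windows; the remaining windows (the `[5.7.9.11]`-style vectors above) are then built with `vextq_f32` against the next structured load. A minimal sketch:

```cpp
// Sketch, assuming in_ptr points at >= 8 readable floats: de-interleave into
// the two stride-2 column phases with a single structured load.
#include <arm_neon.h>

void stride2_windows(const float* in_ptr) {
  float32x4x2_t vvi = vld2q_f32(in_ptr);  // reads in_ptr[0..7]
  float32x4_t vi0 = vvi.val[0];           // [x0 x2 x4 x6] even columns
  float32x4_t vi1 = vvi.val[1];           // [x1 x3 x5 x7] odd columns
  (void)vi0; (void)vi1;
}
```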
@@ -362,7 +440,7 @@ void Conv2dNeonK7x7S2(const float *input,
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + mm * in_channels * 49 + c * 49;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
@@ -391,7 +469,11 @@ void Conv2dNeonK7x7S2(const float *input,
vi5 = vextq_f32(vi1, vvi1.val[1], 2); // [5.7.9.11]
vi6 = vextq_f32(vi0, vvi1.val[0], 3); // [6.8.10.12]
Conv2dNeonK7x7SnLoadCalc1;
#if defined(__aarch64__)
Conv2dArmv8NeonK7x7SnLoadCalc1;
#else
Conv2dArmv7NeonK7x7SnLoadCalc1;
#endif
in_offset += in_width;
filter_ptr0 += 7;
@@ -450,7 +532,7 @@ void Conv2dNeonK7x7S3(const float *input,
filter + (m + 2) * in_channels * 49 + c * 49;
const float *filter_ptr3 =
filter + (m + 3) * in_channels * 49 + c * 49;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
@@ -482,7 +564,11 @@ void Conv2dNeonK7x7S3(const float *input,
vi5 = vextq_f32(vi2, vvi1.val[2], 1); // [5.8.11.14]
vi6 = vextq_f32(vi0, vvi1.val[0], 2); // [6.9.12.15]
Conv2dNeonK7x7SnLoadCalc4;
#if defined(__aarch64__)
Conv2dArmv8NeonK7x7SnLoadCalc4;
#else
Conv2dArmv7NeonK7x7SnLoadCalc4;
#endif
in_offset += in_width;
filter_ptr0 += 7;
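Stride 3 uses the same idea with `vld3q_f32`, which splits twelve consecutive floats into the three stride-3 column phases; shifted windows such as `[5.8.11.14]` above again come from `vextq_f32` against the next structured load. Sketch:

```cpp
// Sketch, assuming in_ptr points at >= 12 readable floats: three stride-3
// phases from one structured load.
#include <arm_neon.h>

void stride3_windows(const float* in_ptr) {
  float32x4x3_t vvi = vld3q_f32(in_ptr);  // reads in_ptr[0..11]
  float32x4_t vi0 = vvi.val[0];           // [x0 x3 x6 x9]
  float32x4_t vi1 = vvi.val[1];           // [x1 x4 x7 x10]
  float32x4_t vi2 = vvi.val[2];           // [x2 x5 x8 x11]
  (void)vi0; (void)vi1; (void)vi2;
}
```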
@@ -518,7 +604,7 @@ void Conv2dNeonK7x7S3(const float *input,
const float *in_ptr_base =
input + b * in_batch_size + c * in_image_size;
const float *filter_ptr0 = filter + mm * in_channels * 49 + c * 49;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#if defined(MACE_ENABLE_NEON)
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input offset
@@ -547,7 +633,11 @@ void Conv2dNeonK7x7S3(const float *input,
vi5 = vextq_f32(vi2, vvi1.val[2], 1); // [5.8.11.14]
vi6 = vextq_f32(vi0, vvi1.val[0], 2); // [6.9.12.15]
Conv2dNeonK7x7SnLoadCalc1;
#if defined(__aarch64__)
Conv2dArmv8NeonK7x7SnLoadCalc1;
#else
Conv2dArmv7NeonK7x7SnLoadCalc1;
#endif
in_offset += in_width;
filter_ptr0 += 7;
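For reviewing or testing a change like this, a plain scalar reference is handy: it should match the ARMv7 path bit-for-bit, while the A64 fused path may differ by rounding in the last bit. A hypothetical single-channel 7 × 7 reference, not part of MACE:

```cpp
// Hypothetical scalar reference (not in the patch): accumulate one input
// channel's 7x7 valid convolution into out for one (batch, out-channel)
// plane, matching the kernels' stride parameter.
void Conv2dRefK7x7(const float* in, const float* filter7x7, float* out,
                   int in_width, int out_height, int out_width, int stride) {
  for (int h = 0; h < out_height; ++h) {
    for (int w = 0; w < out_width; ++w) {
      float sum = out[h * out_width + w];
      for (int kh = 0; kh < 7; ++kh) {
        for (int kw = 0; kw < 7; ++kw) {
          sum += in[(h * stride + kh) * in_width + (w * stride + kw)] *
                 filter7x7[kh * 7 + kw];
        }
      }
      out[h * out_width + w] = sum;
    }
  }
}
```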
......