diff --git a/mace/kernels/arm/conv_2d_neon_7x7.cc b/mace/kernels/arm/conv_2d_neon_7x7.cc index ed40e5c9fd5466fe7a1a37c26de1f4fea3d5c0ab..dcb9fe20dfd9dc27d56dc714c8135f00e6c72c80 100644 --- a/mace/kernels/arm/conv_2d_neon_7x7.cc +++ b/mace/kernels/arm/conv_2d_neon_7x7.cc @@ -21,7 +21,73 @@ namespace mace { namespace kernels { -#define Conv2dNeonK7x7SnLoadCalc4 \ +#define Conv2dArmv8NeonK7x7SnLoadCalc4 \ + /* load filter (4 outch x 1 height x 4 width) */ \ + float32x4_t vf00, vf01; \ + float32x4_t vf10, vf11; \ + float32x4_t vf20, vf21; \ + float32x4_t vf30, vf31; \ + vf00 = vld1q_f32(filter_ptr0); \ + vf01 = vld1q_f32(filter_ptr0 + 4); \ + vf10 = vld1q_f32(filter_ptr1); \ + vf11 = vld1q_f32(filter_ptr1 + 4); \ + vf20 = vld1q_f32(filter_ptr2); \ + vf21 = vld1q_f32(filter_ptr2 + 4); \ + vf30 = vld1q_f32(filter_ptr3); \ + vf31 = vld1q_f32(filter_ptr3 + 4); \ + \ + /* outch 0 */ \ + vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); \ + vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); \ + vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); \ + vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); \ + vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 0); \ + vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 1); \ + vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 2); \ + \ + /* outch 1 */ \ + vo1 = vfmaq_laneq_f32(vo1, vi0, vf10, 0); \ + vo1 = vfmaq_laneq_f32(vo1, vi1, vf10, 1); \ + vo1 = vfmaq_laneq_f32(vo1, vi2, vf10, 2); \ + vo1 = vfmaq_laneq_f32(vo1, vi3, vf10, 3); \ + vo1 = vfmaq_laneq_f32(vo1, vi4, vf11, 0); \ + vo1 = vfmaq_laneq_f32(vo1, vi5, vf11, 1); \ + vo1 = vfmaq_laneq_f32(vo1, vi6, vf11, 2); \ + \ + /* outch 2 */ \ + vo2 = vfmaq_laneq_f32(vo2, vi0, vf20, 0); \ + vo2 = vfmaq_laneq_f32(vo2, vi1, vf20, 1); \ + vo2 = vfmaq_laneq_f32(vo2, vi2, vf20, 2); \ + vo2 = vfmaq_laneq_f32(vo2, vi3, vf20, 3); \ + vo2 = vfmaq_laneq_f32(vo2, vi4, vf21, 0); \ + vo2 = vfmaq_laneq_f32(vo2, vi5, vf21, 1); \ + vo2 = vfmaq_laneq_f32(vo2, vi6, vf21, 2); \ + \ + /* outch 3 */ \ + vo3 = vfmaq_laneq_f32(vo3, vi0, vf30, 0); \ + vo3 = vfmaq_laneq_f32(vo3, vi1, vf30, 1); \ + vo3 = vfmaq_laneq_f32(vo3, vi2, vf30, 2); \ + vo3 = vfmaq_laneq_f32(vo3, vi3, vf30, 3); \ + vo3 = vfmaq_laneq_f32(vo3, vi4, vf31, 0); \ + vo3 = vfmaq_laneq_f32(vo3, vi5, vf31, 1); \ + vo3 = vfmaq_laneq_f32(vo3, vi6, vf31, 2); + +#define Conv2dArmv8NeonK7x7SnLoadCalc1 \ + /* load filter (1 outch x 1 height x 4 width) */ \ + float32x4_t vf00, vf01; \ + vf00 = vld1q_f32(filter_ptr0); \ + vf01 = vld1q_f32(filter_ptr0 + 4); \ + \ + /* outch 0 */ \ + vo0 = vfmaq_laneq_f32(vo0, vi0, vf00, 0); \ + vo0 = vfmaq_laneq_f32(vo0, vi1, vf00, 1); \ + vo0 = vfmaq_laneq_f32(vo0, vi2, vf00, 2); \ + vo0 = vfmaq_laneq_f32(vo0, vi3, vf00, 3); \ + vo0 = vfmaq_laneq_f32(vo0, vi4, vf01, 0); \ + vo0 = vfmaq_laneq_f32(vo0, vi5, vf01, 1); \ + vo0 = vfmaq_laneq_f32(vo0, vi6, vf01, 2); + +#define Conv2dArmv7NeonK7x7SnLoadCalc4 \ /* load filter (4 outch x 1 height x 4 width) */ \ float32x4_t vf00, vf01; \ float32x4_t vf10, vf11; \ @@ -72,7 +138,7 @@ namespace kernels { vo3 = vmlaq_lane_f32(vo3, vi5, vget_low_f32(vf31), 1); \ vo3 = vmlaq_lane_f32(vo3, vi6, vget_high_f32(vf31), 0); -#define Conv2dNeonK7x7SnLoadCalc1 \ +#define Conv2dArmv7NeonK7x7SnLoadCalc1 \ /* load filter (1 outch x 1 height x 4 width) */ \ float32x4_t vf00, vf01; \ vf00 = vld1q_f32(filter_ptr0); \ @@ -148,7 +214,7 @@ void Conv2dNeonK7x7S1(const float *input, filter + (m + 2) * in_channels * 49 + c * 49; const float *filter_ptr3 = filter + (m + 3) * in_channels * 49 + c * 49; -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) for (index_t h = 0; h < out_height; ++h) { for (index_t w = 0; w + 3 < out_width; w += 4) { // input offset @@ -175,7 +241,11 @@ void Conv2dNeonK7x7S1(const float *input, vi5 = vextq_f32(vi4, vi8, 1); vi6 = vextq_f32(vi4, vi8, 2); - Conv2dNeonK7x7SnLoadCalc4; +#if defined(__aarch64__) + Conv2dArmv8NeonK7x7SnLoadCalc4; +#else + Conv2dArmv7NeonK7x7SnLoadCalc4; +#endif in_offset += in_width; filter_ptr0 += 7; @@ -211,7 +281,7 @@ void Conv2dNeonK7x7S1(const float *input, const float *in_ptr_base = input + b * in_batch_size + c * in_image_size; const float *filter_ptr0 = filter + mm * in_channels * 49 + c * 49; -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) for (index_t h = 0; h < out_height; ++h) { for (index_t w = 0; w + 3 < out_width; w += 4) { // input offset @@ -235,7 +305,11 @@ void Conv2dNeonK7x7S1(const float *input, vi5 = vextq_f32(vi4, vi8, 1); vi6 = vextq_f32(vi4, vi8, 2); - Conv2dNeonK7x7SnLoadCalc1; +#if defined(__aarch64__) + Conv2dArmv8NeonK7x7SnLoadCalc1; +#else + Conv2dArmv7NeonK7x7SnLoadCalc1; +#endif in_offset += in_width; filter_ptr0 += 7; @@ -294,7 +368,7 @@ void Conv2dNeonK7x7S2(const float *input, filter + (m + 2) * in_channels * 49 + c * 49; const float *filter_ptr3 = filter + (m + 3) * in_channels * 49 + c * 49; -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) for (index_t h = 0; h < out_height; ++h) { for (index_t w = 0; w + 3 < out_width; w += 4) { // input offset @@ -326,7 +400,11 @@ void Conv2dNeonK7x7S2(const float *input, vi5 = vextq_f32(vi1, vvi1.val[1], 2); // [5.7.9.11] vi6 = vextq_f32(vi0, vvi1.val[0], 3); // [6.8.10.12] - Conv2dNeonK7x7SnLoadCalc4; +#if defined(__aarch64__) + Conv2dArmv8NeonK7x7SnLoadCalc4; +#else + Conv2dArmv7NeonK7x7SnLoadCalc4; +#endif in_offset += in_width; filter_ptr0 += 7; @@ -362,7 +440,7 @@ void Conv2dNeonK7x7S2(const float *input, const float *in_ptr_base = input + b * in_batch_size + c * in_image_size; const float *filter_ptr0 = filter + mm * in_channels * 49 + c * 49; -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) for (index_t h = 0; h < out_height; ++h) { for (index_t w = 0; w + 3 < out_width; w += 4) { // input offset @@ -391,7 +469,11 @@ void Conv2dNeonK7x7S2(const float *input, vi5 = vextq_f32(vi1, vvi1.val[1], 2); // [5.7.9.11] vi6 = vextq_f32(vi0, vvi1.val[0], 3); // [6.8.10.12] - Conv2dNeonK7x7SnLoadCalc1; +#if defined(__aarch64__) + Conv2dArmv8NeonK7x7SnLoadCalc1; +#else + Conv2dArmv7NeonK7x7SnLoadCalc1; +#endif in_offset += in_width; filter_ptr0 += 7; @@ -450,7 +532,7 @@ void Conv2dNeonK7x7S3(const float *input, filter + (m + 2) * in_channels * 49 + c * 49; const float *filter_ptr3 = filter + (m + 3) * in_channels * 49 + c * 49; -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) for (index_t h = 0; h < out_height; ++h) { for (index_t w = 0; w + 3 < out_width; w += 4) { // input offset @@ -482,7 +564,11 @@ void Conv2dNeonK7x7S3(const float *input, vi5 = vextq_f32(vi2, vvi1.val[2], 1); // [5.8.11.14] vi6 = vextq_f32(vi0, vvi1.val[0], 2); // [6.9.12.15] - Conv2dNeonK7x7SnLoadCalc4; +#if defined(__aarch64__) + Conv2dArmv8NeonK7x7SnLoadCalc4; +#else + Conv2dArmv7NeonK7x7SnLoadCalc4; +#endif in_offset += in_width; filter_ptr0 += 7; @@ -518,7 +604,7 @@ void Conv2dNeonK7x7S3(const float *input, const float *in_ptr_base = input + b * in_batch_size + c * in_image_size; const float *filter_ptr0 = filter + mm * in_channels * 49 + c * 49; -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) for (index_t h = 0; h < out_height; ++h) { for (index_t w = 0; w + 3 < out_width; w += 4) { // input offset @@ -547,7 +633,11 @@ void Conv2dNeonK7x7S3(const float *input, vi5 = vextq_f32(vi2, vvi1.val[2], 1); // [5.8.11.14] vi6 = vextq_f32(vi0, vvi1.val[0], 2); // [6.9.12.15] - Conv2dNeonK7x7SnLoadCalc1; +#if defined(__aarch64__) + Conv2dArmv8NeonK7x7SnLoadCalc1; +#else + Conv2dArmv7NeonK7x7SnLoadCalc1; +#endif in_offset += in_width; filter_ptr0 += 7;