From ad0b1a8a0811a5cb8cd29aa28312b5b6e4008c07 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Fri, 27 Apr 2018 11:20:50 +0800 Subject: [PATCH] optimize depthwise conv3x3 s1 s2 armv7 neon --- mace/kernels/arm/depthwise_conv2d_neon_3x3.cc | 43 +++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc b/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc index d0ba9ce2..489a3ce8 100644 --- a/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc +++ b/mace/kernels/arm/depthwise_conv2d_neon_3x3.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) #include <arm_neon.h> #endif @@ -99,7 +99,7 @@ void DepthwiseConv2dNeonK3x3S1(const float *input, } } -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) // load filter (1 outch x 3 height x 3 width): vf_outch_height float32x4_t vf00, vf01, vf02; vf00 = vld1q_f32(filter_ptr); @@ -172,6 +172,7 @@ void DepthwiseConv2dNeonK3x3S1(const float *input, vo00 = vld1q_f32(out_base + out_offset); vo01 = vld1q_f32(out_base + out_offset + out_width); +#if defined(__aarch64__) // outch 0, height 0 vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0); vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1); @@ -193,7 +194,29 @@ void DepthwiseConv2dNeonK3x3S1(const float *input, vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0); vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1); vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2); +#else + // outch 0, height 0 + vo00 = vmlaq_lane_f32(vo00, vi00, vget_low_f32(vf00), 0); + vo00 = vmlaq_lane_f32(vo00, vi01, vget_low_f32(vf00), 1); + vo00 = vmlaq_lane_f32(vo00, vi02, vget_high_f32(vf00), 0); + vo00 = vmlaq_lane_f32(vo00, vi10, vget_low_f32(vf01), 0); + vo00 = vmlaq_lane_f32(vo00, vi11, vget_low_f32(vf01), 1); + vo00 = vmlaq_lane_f32(vo00, vi12, vget_high_f32(vf01), 0); + vo00 = vmlaq_lane_f32(vo00, vi20, 
vget_low_f32(vf02), 0); + vo00 = vmlaq_lane_f32(vo00, vi21, vget_low_f32(vf02), 1); + vo00 = vmlaq_lane_f32(vo00, vi22, vget_high_f32(vf02), 0); + // outch 0, height 1 + vo01 = vmlaq_lane_f32(vo01, vi10, vget_low_f32(vf00), 0); + vo01 = vmlaq_lane_f32(vo01, vi11, vget_low_f32(vf00), 1); + vo01 = vmlaq_lane_f32(vo01, vi12, vget_high_f32(vf00), 0); + vo01 = vmlaq_lane_f32(vo01, vi20, vget_low_f32(vf01), 0); + vo01 = vmlaq_lane_f32(vo01, vi21, vget_low_f32(vf01), 1); + vo01 = vmlaq_lane_f32(vo01, vi22, vget_high_f32(vf01), 0); + vo01 = vmlaq_lane_f32(vo01, vi30, vget_low_f32(vf02), 0); + vo01 = vmlaq_lane_f32(vo01, vi31, vget_low_f32(vf02), 1); + vo01 = vmlaq_lane_f32(vo01, vi32, vget_high_f32(vf02), 0); +#endif vst1q_f32(out_base + out_offset, vo00); vst1q_f32(out_base + out_offset + out_width, vo01); } // w @@ -316,7 +339,7 @@ void DepthwiseConv2dNeonK3x3S2(const float *input, } } -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#if defined(MACE_ENABLE_NEON) // load filter (1 outch x 3 height x 3 width): vf_outch_height float32x4_t vf00, vf01, vf02; vf00 = vld1q_f32(filter_ptr); @@ -378,6 +401,7 @@ void DepthwiseConv2dNeonK3x3S2(const float *input, vi21 = vi2.val[1]; vi22 = vextq_f32(vi20, vi2n, 1); +#if defined(__aarch64__) // outch 0, height 0 vo = vfmaq_laneq_f32(vo, vi00, vf00, 0); vo = vfmaq_laneq_f32(vo, vi01, vf00, 1); @@ -388,7 +412,18 @@ void DepthwiseConv2dNeonK3x3S2(const float *input, vo = vfmaq_laneq_f32(vo, vi20, vf02, 0); vo = vfmaq_laneq_f32(vo, vi21, vf02, 1); vo = vfmaq_laneq_f32(vo, vi22, vf02, 2); - +#else + // outch 0, height 0 + vo = vmlaq_lane_f32(vo, vi00, vget_low_f32(vf00), 0); + vo = vmlaq_lane_f32(vo, vi01, vget_low_f32(vf00), 1); + vo = vmlaq_lane_f32(vo, vi02, vget_high_f32(vf00), 0); + vo = vmlaq_lane_f32(vo, vi10, vget_low_f32(vf01), 0); + vo = vmlaq_lane_f32(vo, vi11, vget_low_f32(vf01), 1); + vo = vmlaq_lane_f32(vo, vi12, vget_high_f32(vf01), 0); + vo = vmlaq_lane_f32(vo, vi20, vget_low_f32(vf02), 0); + vo = 
vmlaq_lane_f32(vo, vi21, vget_low_f32(vf02), 1); + vo = vmlaq_lane_f32(vo, vi22, vget_high_f32(vf02), 0); +#endif vst1q_f32(out_base + out_offset, vo); } // w -- GitLab