diff --git a/mace/kernels/arm/conv_2d_neon.h b/mace/kernels/arm/conv_2d_neon.h index 0b02541297f3d1c7172015e6c0afe091e16a4834..3d3c907e97221eb4970a47f5caba9b69b0e13070 100644 --- a/mace/kernels/arm/conv_2d_neon.h +++ b/mace/kernels/arm/conv_2d_neon.h @@ -51,6 +51,17 @@ extern void Conv2dNeonK3x3S2(const float *input, const index_t out_channels, float *output); +extern void Conv2dNeonK5x5S1(const float *input, + const float *filter, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t out_height, + const index_t out_width, + const index_t out_channels, + float *output); + extern void Conv2dNeonK7x7S1(const float *input, const float *filter, const index_t batch, diff --git a/mace/kernels/arm/conv_2d_neon_5x5.cc b/mace/kernels/arm/conv_2d_neon_5x5.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3af5b2a9cd3752084430bc2eda04c9862f17da0 --- /dev/null +++ b/mace/kernels/arm/conv_2d_neon_5x5.cc @@ -0,0 +1,242 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +#include "mace/kernels/arm/conv_2d_neon.h" + +namespace mace { +namespace kernels { + +#define Conv2dNeonK5x5SnLoadCalc4 \ + /* load filter (4 outch x 1 height x 4 width) */ \ + float32x4_t vf00, vf10, vf20, vf30; \ + float32x2_t vf01, vf11, vf21, vf31; \ + vf00 = vld1q_f32(filter_ptr0); \ + vf01 = vld1_f32(filter_ptr0 + 4); \ + vf10 = vld1q_f32(filter_ptr1); \ + vf11 = vld1_f32(filter_ptr1 + 4); \ + vf20 = vld1q_f32(filter_ptr2); \ + vf21 = vld1_f32(filter_ptr2 + 4); \ + vf30 = vld1q_f32(filter_ptr3); \ + vf31 = vld1_f32(filter_ptr3 + 4); \ + \ + /* outch 0 */ \ + vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0); \ + vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1); \ + vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0); \ + vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); \ + vo0 = vmlaq_lane_f32(vo0, vi4, vf01, 0); \ + \ + /* outch 1 */ \ + vo1 = vmlaq_lane_f32(vo1, vi0, vget_low_f32(vf10), 0); \ + vo1 = vmlaq_lane_f32(vo1, vi1, vget_low_f32(vf10), 1); \ + vo1 = vmlaq_lane_f32(vo1, vi2, vget_high_f32(vf10), 0); \ + vo1 = vmlaq_lane_f32(vo1, vi3, vget_high_f32(vf10), 1); \ + vo1 = vmlaq_lane_f32(vo1, vi4, vf11, 0); \ + \ + /* outch 2 */ \ + vo2 = vmlaq_lane_f32(vo2, vi0, vget_low_f32(vf20), 0); \ + vo2 = vmlaq_lane_f32(vo2, vi1, vget_low_f32(vf20), 1); \ + vo2 = vmlaq_lane_f32(vo2, vi2, vget_high_f32(vf20), 0); \ + vo2 = vmlaq_lane_f32(vo2, vi3, vget_high_f32(vf20), 1); \ + vo2 = vmlaq_lane_f32(vo2, vi4, vf21, 0); \ + \ + /* outch 3 */ \ + vo3 = vmlaq_lane_f32(vo3, vi0, vget_low_f32(vf30), 0); \ + vo3 = vmlaq_lane_f32(vo3, vi1, vget_low_f32(vf30), 1); \ + vo3 = vmlaq_lane_f32(vo3, vi2, vget_high_f32(vf30), 0); \ + vo3 = vmlaq_lane_f32(vo3, vi3, vget_high_f32(vf30), 1); \ + vo3 = vmlaq_lane_f32(vo3, vi4, vf31, 0); + +#define Conv2dNeonK5x5SnLoadCalc1 \ + /* load filter (1 outch x 1 height x 4 width) */ \ + float32x4_t vf00; \ + float32x2_t vf01; \ + vf00 = vld1q_f32(filter_ptr0); \ + vf01 = vld1_f32(filter_ptr0 + 4); \ + \ + /* outch 0 */ \ + vo0 = vmlaq_lane_f32(vo0, vi0, vget_low_f32(vf00), 0); \ + vo0 = vmlaq_lane_f32(vo0, vi1, vget_low_f32(vf00), 1); \ + vo0 = vmlaq_lane_f32(vo0, vi2, vget_high_f32(vf00), 0); \ + vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); \ + vo0 = vmlaq_lane_f32(vo0, vi4, vf01, 0); + +inline void Conv2dCPUK5x5Calc(const float *in_ptr_base, + const float *filter_ptr0, + const index_t in_width, + const index_t in_channels, + const index_t out_height, + const index_t out_width, + const index_t out_image_size, + float *out_ptr0_base, + const index_t io, + const int stride) { + for (index_t ih = 0; ih < out_height; ++ih) { + for (index_t iw = 0; iw < out_width; ++iw) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + out_ptr0_base[io * out_image_size + ih * out_width + iw] += + in_ptr_base[(ih * stride + i) * in_width + (iw * stride + j)] * + filter_ptr0[io * in_channels * 25 + i * 5 + j]; + } + } + } + } +} + + +// Ho = 1, Wo = 4, Co = 4 +void Conv2dNeonK5x5S1(const float *input, + const float *filter, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t out_height, + const index_t out_width, + const index_t out_channels, + float *output) { + const index_t in_image_size = in_height * in_width; + const index_t out_image_size = out_height * out_width; + const index_t in_batch_size = in_channels * in_image_size; + const index_t out_batch_size = out_channels * out_image_size; + +#pragma omp parallel for collapse(2) + for (index_t b = 0; b < batch; ++b) { + for (index_t m = 0; m < out_channels; m += 4) { + if (m + 3 < out_channels) { + float *out_ptr0_base = output + b * out_batch_size + m * out_image_size; + float *out_ptr1_base = + output + b * out_batch_size + (m + 1) * out_image_size; + float *out_ptr2_base = + output + b * out_batch_size + (m + 2) * out_image_size; + float *out_ptr3_base = + output + b * out_batch_size + (m + 3) * out_image_size; + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = filter + m * in_channels * 25 + c * 25; + const float *filter_ptr1 = + filter + (m + 1) * in_channels * 25 + c * 25; + const float *filter_ptr2 = + filter + (m + 2) * in_channels * 25 + c * 25; + const float *filter_ptr3 = + filter + (m + 3) * in_channels * 25 + c * 25; +#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w + 3 < out_width; w += 4) { + // input offset + index_t in_offset = h * in_width + w; + // output (4 outch x 1 height x 4 width): vo_outch_height + float32x4_t vo0, vo1, vo2, vo3; + // load output + index_t out_offset = h * out_width + w; + vo0 = vld1q_f32(out_ptr0_base + out_offset); + vo1 = vld1q_f32(out_ptr1_base + out_offset); + vo2 = vld1q_f32(out_ptr2_base + out_offset); + vo3 = vld1q_f32(out_ptr3_base + out_offset); + for (index_t r = 0; r < 5; ++r) { + // input (3 slide) + float32x4_t vi0, vi1, vi2, vi3, vi4; + // load input + vi0 = vld1q_f32(in_ptr_base + in_offset); + vi4 = vld1q_f32(in_ptr_base + in_offset + 4); + vi1 = vextq_f32(vi0, vi4, 1); + vi2 = vextq_f32(vi0, vi4, 2); + vi3 = vextq_f32(vi0, vi4, 3); + + Conv2dNeonK5x5SnLoadCalc4; + + in_offset += in_width; + filter_ptr0 += 5; + filter_ptr1 += 5; + filter_ptr2 += 5; + filter_ptr3 += 5; + } // r + + vst1q_f32(out_ptr0_base + out_offset, vo0); + vst1q_f32(out_ptr1_base + out_offset, vo1); + vst1q_f32(out_ptr2_base + out_offset, vo2); + vst1q_f32(out_ptr3_base + out_offset, vo3); + + filter_ptr0 -= 25; + filter_ptr1 -= 25; + filter_ptr2 -= 25; + filter_ptr3 -= 25; + } // w + } // h +#else + for (index_t io = 0; io < 4; ++io) { + Conv2dCPUK5x5Calc(in_ptr_base, filter_ptr0, in_width, in_channels, + out_height, out_width, out_image_size, + out_ptr0_base, io, 1); + } // for +#endif + } // c + } else { + for (index_t mm = m; mm < out_channels; ++mm) { + float *out_ptr0_base = + output + b * out_batch_size + mm * out_image_size; + for (index_t c = 0; c < in_channels; ++c) { + const float *in_ptr_base = + input + b * in_batch_size + c * in_image_size; + const float *filter_ptr0 = filter + mm * in_channels * 25 + c * 25; +#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w + 3 < out_width; w += 4) { + // input offset + index_t in_offset = h * in_width + w; + // output (1 outch x 1 height x 4 width): vo_outch_height + float32x4_t vo0; + // load output + index_t out_offset = h * out_width + w; + vo0 = vld1q_f32(out_ptr0_base + out_offset); + for (index_t r = 0; r < 5; ++r) { + // input (3 slide) + float32x4_t vi0, vi1, vi2, vi3, vi4; + // load input + vi0 = vld1q_f32(in_ptr_base + in_offset); + vi4 = vld1q_f32(in_ptr_base + in_offset + 4); + vi1 = vextq_f32(vi0, vi4, 1); + vi2 = vextq_f32(vi0, vi4, 2); + vi3 = vextq_f32(vi0, vi4, 3); + + Conv2dNeonK5x5SnLoadCalc1; + + in_offset += in_width; + filter_ptr0 += 5; + } // r + + vst1q_f32(out_ptr0_base + out_offset, vo0); + filter_ptr0 -= 25; + } // w + } // h +#else + Conv2dCPUK5x5Calc(in_ptr_base, filter_ptr0, in_width, in_channels, + out_height, out_width, out_image_size, + out_ptr0_base, 0, 1); +#endif + } // c + } // mm + } // if + } // m + } // b +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/arm/conv_2d_neon_7x7.cc b/mace/kernels/arm/conv_2d_neon_7x7.cc index 7a8acaa8e8a22d297168ecd5d0313b6cb095b78a..ed40e5c9fd5466fe7a1a37c26de1f4fea3d5c0ab 100644 --- a/mace/kernels/arm/conv_2d_neon_7x7.cc +++ b/mace/kernels/arm/conv_2d_neon_7x7.cc @@ -16,7 +16,7 @@ #include #endif -#include "mace/core/types.h" +#include "mace/kernels/arm/conv_2d_neon.h" namespace mace { namespace kernels { @@ -88,15 +88,15 @@ namespace kernels { vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 0); inline void Conv2dCPUK7x7Calc(const float *in_ptr_base, - const float *filter_ptr0, - const index_t in_width, - const index_t in_channels, - const index_t out_height, - const index_t out_width, - const index_t out_image_size, - float *out_ptr0_base, - const index_t io, - const int stride) { + const float *filter_ptr0, + const index_t in_width, + const index_t in_channels, + const index_t out_height, + const index_t out_width, + const index_t out_image_size, + float *out_ptr0_base, + const index_t io, + const int stride) { for (index_t ih = 0; ih < out_height; ++ih) { for (index_t iw = 0; iw < out_width; ++iw) { for (int i = 0; i < 7; ++i) { diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index c9ac85613c1299ed661d05171a53aef0ad8c8241..acecb578ce0ae2f021c85cffc1bb56d19a63aa18 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -224,6 +224,8 @@ struct Conv2dFunctor : Conv2dFunctorBase { && stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1; bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; + bool use_neon_5x5_s1 = filter_h == 5 && filter_w == 5 + && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7 @@ -294,6 +296,18 @@ struct Conv2dFunctor : Conv2dFunctorBase { if (extra_input_width != padded_input_width) { pad_right += (extra_input_width - padded_input_width); } + } else if (use_neon_5x5_s1) { + extra_output_height = height; + extra_input_height = + std::max(padded_input_height, extra_output_height + 4); + extra_output_width = RoundUp(width, 4); + extra_input_width = std::max(padded_input_width, extra_output_width + 4); + if (extra_input_height != padded_input_height) { + pad_bottom += (extra_input_height - padded_input_height); + } + if (extra_input_width != padded_input_width) { + pad_right += (extra_input_width - padded_input_width); + } } else if (use_neon_7x7_s1) { extra_output_height = height; extra_input_height = @@ -457,6 +471,19 @@ struct Conv2dFunctor : Conv2dFunctorBase { channels, pad_output); }; + } else if (use_neon_5x5_s1) { + conv_func = [=](const float *pad_input, float *pad_output) { + Conv2dNeonK5x5S1(pad_input, + filter_data, + batch, + extra_input_height, + extra_input_width, + input_channels, + extra_output_height, + extra_output_width, + channels, + pad_output); + }; } else if (use_neon_7x7_s1) { conv_func = [=](const float *pad_input, float *pad_output) { Conv2dNeonK7x7S1(pad_input, diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 049e38d201cd8e7bed75b63cfc76a84e80d5d22f..9880ca72a53b57354bad9fdc39602599ef49f914 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -516,7 +516,7 @@ void TestComplexConvNxNS12(const std::vector &shape, *net.GetOutput("OPENCLOutput"), 1e-4, 1e-4); }; - for (int kernel_size : {1, 3, 7}) { + for (int kernel_size : {1, 3, 5, 7}) { func(kernel_size, kernel_size, stride, stride, VALID); func(kernel_size, kernel_size, stride, stride, SAME); }