diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc index bcc1002e4d2116262c3298a163a6859c6b588608..ceb966ff55d5be1338c25a79b0730c27ecc15d4d 100644 --- a/mace/kernels/neon/conv_2d_neon.cc +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -2,9 +2,7 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#include #include "mace/kernels/conv_2d.h" -#include "mace/kernels/neon/conv_2d_neon_3x3.h" namespace mace { namespace kernels { @@ -50,6 +48,9 @@ extern void Conv2dNeonK1x1S1(const float* input, const index_t* input_shape, const float* filter, const float* bias, float* output, const index_t* output_shape); +extern void Conv2dNeonK3x3S1(const float* input, const index_t* input_shape, + const float* filter, const float* bias, + float* output, const index_t* output_shape); template<> void Conv2dFunctor::operator()(const float* input, // NCHW diff --git a/mace/kernels/neon/conv_2d_neon_1x1.cc b/mace/kernels/neon/conv_2d_neon_1x1.cc index b3556c98a5c38ff17e27ce087c47f9b78375d375..86b7b32892abdf2ec651aa2f3b87cd162a469829 100644 --- a/mace/kernels/neon/conv_2d_neon_1x1.cc +++ b/mace/kernels/neon/conv_2d_neon_1x1.cc @@ -3,7 +3,7 @@ // #include -#include "mace/kernels/conv_2d.h" +#include "mace/core/common.h" namespace mace { namespace kernels { diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee194f4081f63c694ea813c1127a344ab23b68a2 --- /dev/null +++ b/mace/kernels/neon/conv_2d_neon_3x3.cc @@ -0,0 +1,217 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include "mace/core/common.h" + +namespace mace { +namespace kernels { + +static const int REGISTER_SIZE = 4; + +void Conv2dNeonK3x3S1(const float* input, // NCHW + const index_t* input_shape, + const float* filter, // c_out, c_in, kernel_h, kernel_w + const float* bias, // c_out + float* output, // NCHW + const index_t* output_shape) { + + int batch = output_shape[0]; + int channels = output_shape[1]; + int height = output_shape[2]; + int width = output_shape[3]; + + int input_batch = input_shape[0]; + int input_channels = input_shape[1]; + int input_height = input_shape[2]; + int input_width = input_shape[3]; + + int kernel_h = 3; + int kernel_w = 3; + + int height_count = (height >> 1) << 1; + for (int b = 0; b < batch; ++b) { + float* output_ptr_base = output + b * channels * height * width; + for (int oc = 0; oc < channels; ++oc) { + const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; + const float* input_ptr = input + b * input_channels * input_height * input_width; + float* output_ptr = output_ptr_base + oc * height * width; + + std::fill(output_ptr, output_ptr + height * width, bias[oc]); + for (int ic = 0; ic < input_channels; ++ic) { + float32x4_t filter0 = vld1q_f32(filter_ptr); + float32x4_t filter3 = vld1q_f32(filter_ptr+3); + float32x4_t filter6 = vld1q_f32(filter_ptr+6); + + const float* row[REGISTER_SIZE] = { + input_ptr, input_ptr + input_width, + input_ptr + 2 * input_width, input_ptr + 3 * input_width + }; + + float* output_ptr1 = output_ptr; + float* output_ptr2 = output_ptr + width; + + for (int h = 0; h < height_count; h += 2) { + + int count = width >> 2; + int remain_count = width & 3; + + for (; count > 0; --count) { + float32x4_t sum0 = vdupq_n_f32(.0f); + float32x4_t sum1 = vdupq_n_f32(.0f); + float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123 + float32x4_t row0_latter = vld1q_f32(row[0] + REGISTER_SIZE); //4567 + float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 + float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 + + sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); + + float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123 + float32x4_t row1_latter = vld1q_f32(row[1] + REGISTER_SIZE); //4567 + float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 + float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 + + sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0); + sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1); + sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); + + row0_ext_0 = vld1q_f32(row[2]); //0123 + row0_latter = vld1q_f32(row[2] + REGISTER_SIZE); //4567 + row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 + row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 + + sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2); + + // second row + sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter0, 0); + sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter0, 1); + sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter0, 2); + + sum1 = vfmaq_laneq_f32(sum1, row0_ext_0, filter3, 0); + sum1 = vfmaq_laneq_f32(sum1, row0_ext_1, filter3, 1); + sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2); + + row1_ext_0 = vld1q_f32(row[3]); //0123 + row1_latter = vld1q_f32(row[3] + REGISTER_SIZE); //4567 + row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 + row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 + + sum1 = vfmaq_laneq_f32(sum1, row1_ext_0, filter6, 0); + sum1 = vfmaq_laneq_f32(sum1, row1_ext_1, filter6, 1); + sum1 = vfmaq_laneq_f32(sum1, row1_ext_2, filter6, 2); + + float32x4_t output_row0 = vld1q_f32(output_ptr1); + float32x4_t output_row1 = vld1q_f32(output_ptr2); + output_row0 = vaddq_f32(output_row0, sum0); + output_row1 = vaddq_f32(output_row1, sum1); + vst1q_f32(output_ptr1, output_row0); + vst1q_f32(output_ptr2, output_row1); + + output_ptr1 += REGISTER_SIZE; + output_ptr2 += REGISTER_SIZE; + for(int i = 0; i < REGISTER_SIZE; ++i) { + row[i] += REGISTER_SIZE; + } + } + for (; remain_count > 0; --remain_count) { + float32x4_t row0 = vld1q_f32(row[0]); //0123 + float32x4_t row1 = vld1q_f32(row[1]); //0123 + float32x4_t row2 = vld1q_f32(row[2]); //0123 + float32x4_t row3 = vld1q_f32(row[3]); //0123 + + float32x4_t sum = vmulq_f32(row0, filter0); + sum = vmlaq_f32(sum, row1, filter3); + sum = vmlaq_f32(sum, row2, filter6); + sum = vsetq_lane_f32(*output_ptr1, sum, 3); + *output_ptr1 = vaddvq_f32(sum); + + sum = vmulq_f32(row1, filter0); + sum = vmlaq_f32(sum, row2, filter3); + sum = vmlaq_f32(sum, row3, filter6); + sum = vsetq_lane_f32(*output_ptr2, sum, 3); + *output_ptr2 = vaddvq_f32(sum); + + ++output_ptr1; + ++output_ptr2; + for(int i = 0; i < REGISTER_SIZE; ++i) { + row[i] += 1; + } + } + output_ptr1 += width; + output_ptr2 += width; + for(int i = 0; i < REGISTER_SIZE; ++i) { + row[i] += 2 + input_width; + } + } + + if (height != height_count) { + int count = width >> 2; + int remain_count = width & 3; + for(; count > 0; --count) { + float32x4_t sum0 = vdupq_n_f32(.0f); + float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123 + float32x4_t row0_latter = vld1q_f32(row[0] + REGISTER_SIZE); //4567 + float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 + float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 + + sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter0, 0); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter0, 1); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2); + + float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123 + float32x4_t row1_latter = vld1q_f32(row[1] + REGISTER_SIZE); //4567 + float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234 + float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345 + + sum0 = vfmaq_laneq_f32(sum0, row1_ext_0, filter3, 0); + sum0 = vfmaq_laneq_f32(sum0, row1_ext_1, filter3, 1); + sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2); + + row0_ext_0 = vld1q_f32(row[2]); //0123 + row0_latter = vld1q_f32(row[2] + REGISTER_SIZE); //4567 + row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234 + row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345 + + sum0 = vfmaq_laneq_f32(sum0, row0_ext_0, filter6, 0); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_1, filter6, 1); + sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter6, 2); + + float32x4_t output_row0 = vld1q_f32(output_ptr1); + output_row0 = vaddq_f32(output_row0, sum0); + vst1q_f32(output_ptr1, output_row0); + output_ptr1 += REGISTER_SIZE; + for(int i = 0; i < 3; ++i) { + row[i] += REGISTER_SIZE; + } + } + for (; remain_count > 0; --remain_count) { + float32x4_t row0 = vld1q_f32(row[0]); //0123 + float32x4_t row1 = vld1q_f32(row[1]); //0123 + float32x4_t row2 = vld1q_f32(row[2]); //0123 + + float32x4_t sum = vmulq_f32(row0, filter0); + sum = vmlaq_f32(sum, row1, filter3); + sum = vmlaq_f32(sum, row2, filter6); + sum = vsetq_lane_f32(*output_ptr1, sum, 3); + *output_ptr1 = vaddvq_f32(sum); + + ++output_ptr1; + for(int i = 0; i < 3; ++i) { + row[i] += 1; + } + } + } + filter_ptr += 9; + input_ptr += input_height * input_width; + } + } + } +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/neon/conv_2d_neon_3x3.h b/mace/kernels/neon/conv_2d_neon_3x3.h deleted file mode 100644 index 9916e3e03dd6bf4139aa32dbc487c7447119f425..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/conv_2d_neon_3x3.h +++ /dev/null @@ -1,25 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// -#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_ -#define MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_ - -#include -#include "mace/core/common.h" - -namespace mace { -namespace kernels { - -void Conv2dNeonK3x3S1(const float* input, // NCHW - const index_t* input_shape, - const float* filter, // c_out, c_in, kernel_h, kernel_w - const float* bias, // c_out - float* output, // NCHW - const index_t* output_shape) { - -} - -} // namespace kernels -} // namespace mace - -#endif // MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_ diff --git a/mace/kernels/test/conv_2d_neon_3x3_test.cc b/mace/kernels/test/conv_2d_neon_3x3_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d7fc93bf3862c7528714c8e3ec4e2c355d584e3c --- /dev/null +++ b/mace/kernels/test/conv_2d_neon_3x3_test.cc @@ -0,0 +1,82 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "gtest/gtest.h" +#include "mace/kernels/conv_2d.h" +#include "mace/kernels/conv_pool_2d_util.h" + +namespace mace { + +TEST(Conv2dNeon3X3Test, Correctness) { + + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution nd(0, 1); + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 16; + index_t channels = 3 + rand() % 100; + index_t height = 10 + rand() % 100; + index_t width = 10 + rand() % 100; + index_t output_channels = 3 + rand() % 100; + + index_t input_size = batch * channels * height * width; + index_t filter_size = output_channels * channels * 3 * 3; + std::vector input(input_size, 0.0); + const index_t input_shape[] = {batch, channels, height, width}; + std::vector filter(filter_size, 0.0); + const index_t filter_shape[] = {output_channels, channels, 3, 3}; + std::vector bias(output_channels, 0.0); + const int dilations[] = {1, 1}; + const int strides[] = {1, 1}; + + // declare output + vector output_shape; + vector padding_size; + kernels::CalcPaddingAndOutputSize(input_shape, filter_shape, dilations, strides, VALID, + &output_shape, &padding_size); + + const index_t output_size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]; + std::unique_ptr output(new float[output_size]); + std::unique_ptr output_neon(new float[output_size]); + + + for (int i = 0; i < input_size; ++i) { + input[i] = nd(gen); + } + for (int i = 0; i < filter_size; ++i) { + filter[i] = nd(gen); + } + for (int i = 0; i < output_channels; ++i) { + bias[i] = nd(gen); + } + + kernels::Conv2dFunctor(strides, padding_size.data(), dilations)( + input.data(), + input_shape, + filter.data(), + filter_shape, + bias.data(), + output.get(), + output_shape.data() + ); + + kernels::Conv2dFunctor(strides, padding_size.data(), dilations)( + input.data(), + input_shape, + filter.data(), + filter_shape, + bias.data(), + output_neon.get(), + output_shape.data() + ); + + + for (index_t i = 0; i < output_size; ++i) { + EXPECT_NEAR(output[i], output_neon[i], 1e-3); + } +} + +} // namespace mace \ No newline at end of file diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 531d81085298fd1e2b92a801cc6a045875177b95..1fe070e2e37e87aeca2305e7bd64dbd790c79641 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -35,7 +35,7 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, net.AddRandomInput("Filter", {output_channels, channels, kernel_h, kernel_w}); net.AddRandomInput("Bias", {output_channels}); - // Worm-up + // Warm-up for (int i = 0; i < 5; ++i) { net.RunOp(D); } @@ -61,5 +61,6 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); } // namespace mace