From 81a2ab7cf55da6c3fbb206b54b6fcf73328fd271 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Mon, 16 Apr 2018 15:51:42 +0800 Subject: [PATCH] GEMM Neon v7 --- mace/kernels/BUILD | 4 +- mace/kernels/arm/conv_winograd_test.cc | 37 ++++--- mace/kernels/gemm.cc | 144 ++++++++++++++++++++++--- mace/ops/BUILD | 4 +- 4 files changed, 160 insertions(+), 29 deletions(-) diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index c8ecb5b8..1551eb47 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -7,7 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled") +load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7") cc_library( name = "kernels", @@ -28,7 +28,7 @@ cc_library( "opencl/*.h", "arm/*.h", ]), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]), + copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]), linkopts = if_android(["-lm"]), deps = [ "//mace/core", diff --git a/mace/kernels/arm/conv_winograd_test.cc b/mace/kernels/arm/conv_winograd_test.cc index 04a9895a..9dcc10e2 100644 --- a/mace/kernels/arm/conv_winograd_test.cc +++ b/mace/kernels/arm/conv_winograd_test.cc @@ -9,6 +9,7 @@ #include "mace/kernels/arm/conv_winograd.h" #include "mace/core/types.h" +#include "mace/core/tensor.h" namespace mace { namespace kernels { @@ -22,45 +23,55 @@ TEST(ConvWinogradTest, winograd) { index_t out_height = in_height - 2; index_t out_width = in_width - 2; - index_t input_size = batch * in_channels * in_height * out_height; + index_t input_size = batch * in_channels * in_height * in_width; index_t filter_size = 3 * 3 * in_channels * out_channels; index_t output_size = batch * out_channels * out_height * out_width; - std::unique_ptr input_data(new float[input_size]); - std::unique_ptr filter_data(new float[filter_size]); - 
std::unique_ptr output_data(new float[output_size]); - std::unique_ptr output_data_ref(new float[output_size]); + Tensor input; + Tensor filter; + Tensor output; + Tensor output_ref; + + input.Resize({batch, in_channels, in_height, in_width}); + filter.Resize({out_channels, in_channels, 3, 3}); + output.Resize({batch, out_channels, out_height, out_width}); + output_ref.Resize({batch, out_channels, out_height, out_width}); + + float *input_data = input.mutable_data(); + float *filter_data = filter.mutable_data(); + float *output_data = output.mutable_data(); + float *output_data_ref = output_ref.mutable_data(); std::random_device rd; std::mt19937 gen(rd()); std::normal_distribution nd(0, 1); - std::generate(input_data.get(), input_data.get() + input_size, + std::generate(input_data, input_data + input_size, [&gen, &nd] { return std::max(-1.0f, std::min(1.0f, nd(gen))); }); - std::generate(filter_data.get(), filter_data.get() + filter_size, + std::generate(filter_data, filter_data + filter_size, [&gen, &nd] { return std::max(-1.0f, std::min(1.0f, nd(gen))); }); - kernels::ConvRef3x3s1(input_data.get(), - filter_data.get(), + kernels::ConvRef3x3s1(input_data, + filter_data, batch, in_height, in_width, in_channels, out_channels, - output_data_ref.get()); + output_data_ref); - kernels::WinoGradConv3x3s1(input_data.get(), - filter_data.get(), + kernels::WinoGradConv3x3s1(input_data, + filter_data, batch, in_height, in_width, in_channels, out_channels, 6, - output_data.get()); + output_data); // test for (index_t i = 0; i < output_size; ++i) { diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index eb5e104d..aea63c87 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -5,10 +5,16 @@ #include #include +#if defined(MACE_ENABLE_NEON) +#include +#endif + #include "mace/kernels/gemm.h" #include "mace/utils/utils.h" #include "mace/utils/logging.h" + + namespace mace { namespace kernels { @@ -119,12 +125,11 @@ inline void GemmTile(const float *A, const index_t 
stride_w, float *C) { index_t h, w, k; + +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) for (h = 0; h + 7 < height; h += 8) { for (k = 0; k + 7 < K; k += 8) { const float *a_ptr = A + (h * stride_k + k); - -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) - #ifdef __clang__ int nw = width >> 2; if (nw > 0) { @@ -388,21 +393,132 @@ inline void GemmTile(const float *A, float *c_ptr = C + (h * stride_w + w); Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); } -#endif - -#else - for (w = 0; w + 3 < width; w += 4) { +#endif // clang + if (w < width) { + const float *a_ptr = A + (h * stride_k + k); const float *b_ptr = B + (k * stride_w + w); float *c_ptr = C + (h * stride_w + w); - GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr); } -#endif + } + if (k < K) { + const float *a_ptr = A + (h * stride_k + k); + const float *b_ptr = B + k * stride_w; + float *c_ptr = C + h * stride_w; + GemmBlock(a_ptr, + b_ptr, + 8, + K - k, + width, + stride_k, + stride_w, + c_ptr); + } + } + if (h < height) { + // TODO(liyin): may use Gemm444 + const float *a_ptr = A + (h * stride_k); + const float *b_ptr = B; + float *c_ptr = C + h * stride_w; + GemmBlock(a_ptr, + b_ptr, + height - h, + K, + width, + stride_k, + stride_w, + c_ptr); + } +#else + +#if defined(MACE_ENABLE_NEON) // armv7 + for (h = 0; h + 3 < height; h += 4) { + for (k = 0; k + 3 < K; k += 4) { + const float *a_ptr = A + (h * stride_k + k); + int nw = width >> 2; + if (nw > 0) { + // load A + float32x2_t a00, a01, a10, a11, a20, a21, a30, a31; + a00 = vld1_f32(a_ptr); + a01 = vld1_f32(a_ptr + 2); + a10 = vld1_f32(a_ptr + 1 * stride_k); + a11 = vld1_f32(a_ptr + 1 * stride_k + 2); + a20 = vld1_f32(a_ptr + 2 * stride_k); + a21 = vld1_f32(a_ptr + 2 * stride_k + 2); + a30 = vld1_f32(a_ptr + 3 * stride_k); + a31 = vld1_f32(a_ptr + 3 * stride_k + 2); + + const float *b_ptr0 = B + k * stride_w; + const float *b_ptr1 = B + (k + 1) * 
stride_w; + const float *b_ptr2 = B + (k + 2) * stride_w; + const float *b_ptr3 = B + (k + 3) * stride_w; + + float *c_ptr0 = C + h * stride_w; + float *c_ptr1 = C + (h + 1) * stride_w; + float *c_ptr2 = C + (h + 2) * stride_w; + float *c_ptr3 = C + (h + 3) * stride_w; + + // TODO(liyin): asm v7 prefetch and load optimization + while (nw--) { + float32x4_t b0, b1, b2, b3; + float32x4_t c0; + + c0 = vld1q_f32(c_ptr0); + + b0 = vld1q_f32(b_ptr0); + b1 = vld1q_f32(b_ptr1); + b2 = vld1q_f32(b_ptr2); + b3 = vld1q_f32(b_ptr3); + + c0 = vmlaq_lane_f32(c0, b0, a00, 0); + c0 = vmlaq_lane_f32(c0, b1, a00, 1); + c0 = vmlaq_lane_f32(c0, b2, a01, 0); + c0 = vmlaq_lane_f32(c0, b3, a01, 1); + + vst1q_f32(c_ptr0, c0); + c0 = vld1q_f32(c_ptr1); + + c0 = vmlaq_lane_f32(c0, b0, a10, 0); + c0 = vmlaq_lane_f32(c0, b1, a10, 1); + c0 = vmlaq_lane_f32(c0, b2, a11, 0); + c0 = vmlaq_lane_f32(c0, b3, a11, 1); + + vst1q_f32(c_ptr1, c0); + c0 = vld1q_f32(c_ptr2); + + c0 = vmlaq_lane_f32(c0, b0, a20, 0); + c0 = vmlaq_lane_f32(c0, b1, a20, 1); + c0 = vmlaq_lane_f32(c0, b2, a21, 0); + c0 = vmlaq_lane_f32(c0, b3, a21, 1); + + vst1q_f32(c_ptr2, c0); + c0 = vld1q_f32(c_ptr3); + c0 = vmlaq_lane_f32(c0, b0, a30, 0); + c0 = vmlaq_lane_f32(c0, b1, a30, 1); + c0 = vmlaq_lane_f32(c0, b2, a31, 0); + c0 = vmlaq_lane_f32(c0, b3, a31, 1); + + vst1q_f32(c_ptr3, c0); + + b_ptr0 += 4; + b_ptr1 += 4; + b_ptr2 += 4; + b_ptr3 += 4; + + c_ptr0 += 4; + c_ptr1 += 4; + c_ptr2 += 4; + c_ptr3 += 4; + } + + w = (width >> 2) << 2; + } if (w < width) { const float *a_ptr = A + (h * stride_k + k); const float *b_ptr = B + (k * stride_w + w); float *c_ptr = C + (h * stride_w + w); - GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr); + GemmBlock(a_ptr, b_ptr, 4, 4, width - w, stride_k, stride_w, c_ptr); } } if (k < K) { @@ -411,7 +527,7 @@ inline void GemmTile(const float *A, float *c_ptr = C + h * stride_w; GemmBlock(a_ptr, b_ptr, - 8, + 4, K - k, width, stride_k, @@ -420,7 +536,6 @@ inline void 
GemmTile(const float *A, } } if (h < height) { - // TODO(liyin): may use Gemm444 const float *a_ptr = A + (h * stride_k); const float *b_ptr = B; float *c_ptr = C + h * stride_w; @@ -433,6 +548,11 @@ inline void GemmTile(const float *A, stride_w, c_ptr); } +#else // cpu + GemmBlock(A, B, height, K, width, stride_k, stride_w, C); +#endif // armv7 + +#endif // aarch64 } } // namespace diff --git a/mace/ops/BUILD b/mace/ops/BUILD index dbe6e5e2..c7a9b95b 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -7,7 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled") +load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7") cc_library( name = "test", @@ -34,7 +34,7 @@ cc_library( ["*.h"], exclude = ["ops_test_util.h"], ), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]), + copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]), deps = [ "//mace/kernels", ], -- GitLab