提交 47bff05a 编写于 作者: 吴承辉

Merge branch 'armv7' into 'master'

GEMM Neon v7

See merge request !385
...@@ -7,7 +7,7 @@ package( ...@@ -7,7 +7,7 @@ package(
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled") load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7")
cc_library( cc_library(
name = "kernels", name = "kernels",
...@@ -28,7 +28,7 @@ cc_library( ...@@ -28,7 +28,7 @@ cc_library(
"opencl/*.h", "opencl/*.h",
"arm/*.h", "arm/*.h",
]), ]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]), copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]),
linkopts = if_android(["-lm"]), linkopts = if_android(["-lm"]),
deps = [ deps = [
"//mace/core", "//mace/core",
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "mace/kernels/arm/conv_winograd.h" #include "mace/kernels/arm/conv_winograd.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/tensor.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
...@@ -22,45 +23,55 @@ TEST(ConvWinogradTest, winograd) { ...@@ -22,45 +23,55 @@ TEST(ConvWinogradTest, winograd) {
index_t out_height = in_height - 2; index_t out_height = in_height - 2;
index_t out_width = in_width - 2; index_t out_width = in_width - 2;
index_t input_size = batch * in_channels * in_height * out_height; index_t input_size = batch * in_channels * in_height * in_width;
index_t filter_size = 3 * 3 * in_channels * out_channels; index_t filter_size = 3 * 3 * in_channels * out_channels;
index_t output_size = batch * out_channels * out_height * out_width; index_t output_size = batch * out_channels * out_height * out_width;
std::unique_ptr<float[]> input_data(new float[input_size]); Tensor input;
std::unique_ptr<float[]> filter_data(new float[filter_size]); Tensor filter;
std::unique_ptr<float[]> output_data(new float[output_size]); Tensor output;
std::unique_ptr<float[]> output_data_ref(new float[output_size]); Tensor output_ref;
input.Resize({batch, in_channels, in_height, in_width});
filter.Resize({out_channels, in_channels, 3, 3});
output.Resize({batch, out_channels, out_height, out_width});
output_ref.Resize({batch, out_channels, out_height, out_width});
float *input_data = input.mutable_data<float>();
float *filter_data = filter.mutable_data<float>();
float *output_data = output.mutable_data<float>();
float *output_data_ref = output.mutable_data<float>();
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1); std::normal_distribution<float> nd(0, 1);
std::generate(input_data.get(), input_data.get() + input_size, std::generate(input_data, input_data + input_size,
[&gen, &nd] { [&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen))); return std::max(-1.0f, std::min(1.0f, nd(gen)));
}); });
std::generate(filter_data.get(), filter_data.get() + filter_size, std::generate(filter_data, filter_data + filter_size,
[&gen, &nd] { [&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen))); return std::max(-1.0f, std::min(1.0f, nd(gen)));
}); });
kernels::ConvRef3x3s1(input_data.get(), kernels::ConvRef3x3s1(input_data,
filter_data.get(), filter_data,
batch, batch,
in_height, in_height,
in_width, in_width,
in_channels, in_channels,
out_channels, out_channels,
output_data_ref.get()); output_data_ref);
kernels::WinoGradConv3x3s1(input_data.get(), kernels::WinoGradConv3x3s1(input_data,
filter_data.get(), filter_data,
batch, batch,
in_height, in_height,
in_width, in_width,
in_channels, in_channels,
out_channels, out_channels,
6, 6,
output_data.get()); output_data);
// test // test
for (index_t i = 0; i < output_size; ++i) { for (index_t i = 0; i < output_size; ++i) {
......
...@@ -5,10 +5,16 @@ ...@@ -5,10 +5,16 @@
#include <math.h> #include <math.h>
#include <algorithm> #include <algorithm>
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/gemm.h" #include "mace/kernels/gemm.h"
#include "mace/utils/utils.h" #include "mace/utils/utils.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
...@@ -119,12 +125,11 @@ inline void GemmTile(const float *A, ...@@ -119,12 +125,11 @@ inline void GemmTile(const float *A,
const index_t stride_w, const index_t stride_w,
float *C) { float *C) {
index_t h, w, k; index_t h, w, k;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
for (h = 0; h + 7 < height; h += 8) { for (h = 0; h + 7 < height; h += 8) {
for (k = 0; k + 7 < K; k += 8) { for (k = 0; k + 7 < K; k += 8) {
const float *a_ptr = A + (h * stride_k + k); const float *a_ptr = A + (h * stride_k + k);
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#ifdef __clang__ #ifdef __clang__
int nw = width >> 2; int nw = width >> 2;
if (nw > 0) { if (nw > 0) {
...@@ -388,21 +393,132 @@ inline void GemmTile(const float *A, ...@@ -388,21 +393,132 @@ inline void GemmTile(const float *A,
float *c_ptr = C + (h * stride_w + w); float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr); Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
} }
#endif #endif // clang
if (w < width) {
#else const float *a_ptr = A + (h * stride_k + k);
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w); const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w); float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr); GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
} }
#endif }
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
8,
K - k,
width,
stride_k,
stride_w,
c_ptr);
}
}
if (h < height) {
// TODO(liyin): may use Gemm444
const float *a_ptr = A + (h * stride_k);
const float *b_ptr = B;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
height - h,
K,
width,
stride_k,
stride_w,
c_ptr);
}
#else
#if defined(MACE_ENABLE_NEON) // armv7
for (h = 0; h + 3 < height; h += 4) {
for (k = 0; k + 3 < K; k += 4) {
const float *a_ptr = A + (h * stride_k + k);
int nw = width >> 2;
if (nw > 0) {
// load A
float32x2_t a00, a01, a10, a11, a20, a21, a30, a31;
a00 = vld1_f32(a_ptr);
a01 = vld1_f32(a_ptr + 2);
a10 = vld1_f32(a_ptr + 1 * stride_k);
a11 = vld1_f32(a_ptr + 1 * stride_k + 2);
a20 = vld1_f32(a_ptr + 2 * stride_k);
a21 = vld1_f32(a_ptr + 2 * stride_k + 2);
a30 = vld1_f32(a_ptr + 3 * stride_k);
a31 = vld1_f32(a_ptr + 3 * stride_k + 2);
const float *b_ptr0 = B + k * stride_w;
const float *b_ptr1 = B + (k + 1) * stride_w;
const float *b_ptr2 = B + (k + 2) * stride_w;
const float *b_ptr3 = B + (k + 3) * stride_w;
float *c_ptr0 = C + h * stride_w;
float *c_ptr1 = C + (h + 1) * stride_w;
float *c_ptr2 = C + (h + 2) * stride_w;
float *c_ptr3 = C + (h + 3) * stride_w;
// TODO(liyin): asm v7 prefetch and load optimization
while (nw--) {
float32x4_t b0, b1, b2, b3;
float32x4_t c0;
c0 = vld1q_f32(c_ptr0);
b0 = vld1q_f32(b_ptr0);
b1 = vld1q_f32(b_ptr1);
b2 = vld1q_f32(b_ptr2);
b3 = vld1q_f32(b_ptr3);
c0 = vmlaq_lane_f32(c0, b0, a00, 0);
c0 = vmlaq_lane_f32(c0, b1, a00, 1);
c0 = vmlaq_lane_f32(c0, b2, a01, 0);
c0 = vmlaq_lane_f32(c0, b3, a01, 1);
vst1q_f32(c_ptr0, c0);
c0 = vld1q_f32(c_ptr1);
c0 = vmlaq_lane_f32(c0, b0, a10, 0);
c0 = vmlaq_lane_f32(c0, b1, a10, 1);
c0 = vmlaq_lane_f32(c0, b2, a11, 0);
c0 = vmlaq_lane_f32(c0, b3, a11, 1);
vst1q_f32(c_ptr1, c0);
c0 = vld1q_f32(c_ptr2);
c0 = vmlaq_lane_f32(c0, b0, a20, 0);
c0 = vmlaq_lane_f32(c0, b1, a20, 1);
c0 = vmlaq_lane_f32(c0, b2, a21, 0);
c0 = vmlaq_lane_f32(c0, b3, a21, 1);
vst1q_f32(c_ptr2, c0);
c0 = vld1q_f32(c_ptr3);
c0 = vmlaq_lane_f32(c0, b0, a30, 0);
c0 = vmlaq_lane_f32(c0, b1, a30, 1);
c0 = vmlaq_lane_f32(c0, b2, a31, 0);
c0 = vmlaq_lane_f32(c0, b3, a31, 1);
vst1q_f32(c_ptr3, c0);
b_ptr0 += 4;
b_ptr1 += 4;
b_ptr2 += 4;
b_ptr3 += 4;
c_ptr0 += 4;
c_ptr1 += 4;
c_ptr2 += 4;
c_ptr3 += 4;
}
w = (width >> 2) << 2;
}
if (w < width) { if (w < width) {
const float *a_ptr = A + (h * stride_k + k); const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + (k * stride_w + w); const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w); float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr); GemmBlock(a_ptr, b_ptr, 4, 4, width - w, stride_k, stride_w, c_ptr);
} }
} }
if (k < K) { if (k < K) {
...@@ -411,7 +527,7 @@ inline void GemmTile(const float *A, ...@@ -411,7 +527,7 @@ inline void GemmTile(const float *A,
float *c_ptr = C + h * stride_w; float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr, GemmBlock(a_ptr,
b_ptr, b_ptr,
8, 4,
K - k, K - k,
width, width,
stride_k, stride_k,
...@@ -420,7 +536,6 @@ inline void GemmTile(const float *A, ...@@ -420,7 +536,6 @@ inline void GemmTile(const float *A,
} }
} }
if (h < height) { if (h < height) {
// TODO(liyin): may use Gemm444
const float *a_ptr = A + (h * stride_k); const float *a_ptr = A + (h * stride_k);
const float *b_ptr = B; const float *b_ptr = B;
float *c_ptr = C + h * stride_w; float *c_ptr = C + h * stride_w;
...@@ -433,6 +548,11 @@ inline void GemmTile(const float *A, ...@@ -433,6 +548,11 @@ inline void GemmTile(const float *A,
stride_w, stride_w,
c_ptr); c_ptr);
} }
#else // cpu
GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
#endif // armv7
#endif // aarch64
} }
} // namespace } // namespace
......
...@@ -7,7 +7,7 @@ package( ...@@ -7,7 +7,7 @@ package(
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled") load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7")
cc_library( cc_library(
name = "test", name = "test",
...@@ -34,7 +34,7 @@ cc_library( ...@@ -34,7 +34,7 @@ cc_library(
["*.h"], ["*.h"],
exclude = ["ops_test_util.h"], exclude = ["ops_test_util.h"],
), ),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]), copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]),
deps = [ deps = [
"//mace/kernels", "//mace/kernels",
], ],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册