提交 81a2ab7c 编写于 作者: 李寅

GEMM Neon v7

上级 680f8b42
......@@ -7,7 +7,7 @@ package(
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled")
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7")
cc_library(
name = "kernels",
......@@ -28,7 +28,7 @@ cc_library(
"opencl/*.h",
"arm/*.h",
]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]),
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
......
......@@ -9,6 +9,7 @@
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/core/types.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
......@@ -22,45 +23,55 @@ TEST(ConvWinogradTest, winograd) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t input_size = batch * in_channels * in_height * out_height;
index_t input_size = batch * in_channels * in_height * in_width;
index_t filter_size = 3 * 3 * in_channels * out_channels;
index_t output_size = batch * out_channels * out_height * out_width;
std::unique_ptr<float[]> input_data(new float[input_size]);
std::unique_ptr<float[]> filter_data(new float[filter_size]);
std::unique_ptr<float[]> output_data(new float[output_size]);
std::unique_ptr<float[]> output_data_ref(new float[output_size]);
Tensor input;
Tensor filter;
Tensor output;
Tensor output_ref;
input.Resize({batch, in_channels, in_height, in_width});
filter.Resize({out_channels, in_channels, 3, 3});
output.Resize({batch, out_channels, out_height, out_width});
output_ref.Resize({batch, out_channels, out_height, out_width});
float *input_data = input.mutable_data<float>();
float *filter_data = filter.mutable_data<float>();
float *output_data = output.mutable_data<float>();
float *output_data_ref = output.mutable_data<float>();
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(input_data.get(), input_data.get() + input_size,
std::generate(input_data, input_data + input_size,
[&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen)));
});
std::generate(filter_data.get(), filter_data.get() + filter_size,
std::generate(filter_data, filter_data + filter_size,
[&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen)));
});
kernels::ConvRef3x3s1(input_data.get(),
filter_data.get(),
kernels::ConvRef3x3s1(input_data,
filter_data,
batch,
in_height,
in_width,
in_channels,
out_channels,
output_data_ref.get());
output_data_ref);
kernels::WinoGradConv3x3s1(input_data.get(),
filter_data.get(),
kernels::WinoGradConv3x3s1(input_data,
filter_data,
batch,
in_height,
in_width,
in_channels,
out_channels,
6,
output_data.get());
output_data);
// test
for (index_t i = 0; i < output_size; ++i) {
......
......@@ -5,10 +5,16 @@
#include <math.h>
#include <algorithm>
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
#include "mace/utils/logging.h"
namespace mace {
namespace kernels {
......@@ -119,12 +125,11 @@ inline void GemmTile(const float *A,
const index_t stride_w,
float *C) {
index_t h, w, k;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
for (h = 0; h + 7 < height; h += 8) {
for (k = 0; k + 7 < K; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#ifdef __clang__
int nw = width >> 2;
if (nw > 0) {
......@@ -388,21 +393,132 @@ inline void GemmTile(const float *A,
float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
}
#endif
#else
for (w = 0; w + 3 < width; w += 4) {
#endif // clang
if (w < width) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
}
#endif
}
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
8,
K - k,
width,
stride_k,
stride_w,
c_ptr);
}
}
if (h < height) {
// TODO(liyin): may use Gemm444
const float *a_ptr = A + (h * stride_k);
const float *b_ptr = B;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
height - h,
K,
width,
stride_k,
stride_w,
c_ptr);
}
#else
#if defined(MACE_ENABLE_NEON) // armv7
for (h = 0; h + 3 < height; h += 4) {
for (k = 0; k + 3 < K; k += 4) {
const float *a_ptr = A + (h * stride_k + k);
int nw = width >> 2;
if (nw > 0) {
// load A
float32x2_t a00, a01, a10, a11, a20, a21, a30, a31;
a00 = vld1_f32(a_ptr);
a01 = vld1_f32(a_ptr + 2);
a10 = vld1_f32(a_ptr + 1 * stride_k);
a11 = vld1_f32(a_ptr + 1 * stride_k + 2);
a20 = vld1_f32(a_ptr + 2 * stride_k);
a21 = vld1_f32(a_ptr + 2 * stride_k + 2);
a30 = vld1_f32(a_ptr + 3 * stride_k);
a31 = vld1_f32(a_ptr + 3 * stride_k + 2);
const float *b_ptr0 = B + k * stride_w;
const float *b_ptr1 = B + (k + 1) * stride_w;
const float *b_ptr2 = B + (k + 2) * stride_w;
const float *b_ptr3 = B + (k + 3) * stride_w;
float *c_ptr0 = C + h * stride_w;
float *c_ptr1 = C + (h + 1) * stride_w;
float *c_ptr2 = C + (h + 2) * stride_w;
float *c_ptr3 = C + (h + 3) * stride_w;
// TODO(liyin): asm v7 prefetch and load optimization
while (nw--) {
float32x4_t b0, b1, b2, b3;
float32x4_t c0;
c0 = vld1q_f32(c_ptr0);
b0 = vld1q_f32(b_ptr0);
b1 = vld1q_f32(b_ptr1);
b2 = vld1q_f32(b_ptr2);
b3 = vld1q_f32(b_ptr3);
c0 = vmlaq_lane_f32(c0, b0, a00, 0);
c0 = vmlaq_lane_f32(c0, b1, a00, 1);
c0 = vmlaq_lane_f32(c0, b2, a01, 0);
c0 = vmlaq_lane_f32(c0, b3, a01, 1);
vst1q_f32(c_ptr0, c0);
c0 = vld1q_f32(c_ptr1);
c0 = vmlaq_lane_f32(c0, b0, a10, 0);
c0 = vmlaq_lane_f32(c0, b1, a10, 1);
c0 = vmlaq_lane_f32(c0, b2, a11, 0);
c0 = vmlaq_lane_f32(c0, b3, a11, 1);
vst1q_f32(c_ptr1, c0);
c0 = vld1q_f32(c_ptr2);
c0 = vmlaq_lane_f32(c0, b0, a20, 0);
c0 = vmlaq_lane_f32(c0, b1, a20, 1);
c0 = vmlaq_lane_f32(c0, b2, a21, 0);
c0 = vmlaq_lane_f32(c0, b3, a21, 1);
vst1q_f32(c_ptr2, c0);
c0 = vld1q_f32(c_ptr3);
c0 = vmlaq_lane_f32(c0, b0, a30, 0);
c0 = vmlaq_lane_f32(c0, b1, a30, 1);
c0 = vmlaq_lane_f32(c0, b2, a31, 0);
c0 = vmlaq_lane_f32(c0, b3, a31, 1);
vst1q_f32(c_ptr3, c0);
b_ptr0 += 4;
b_ptr1 += 4;
b_ptr2 += 4;
b_ptr3 += 4;
c_ptr0 += 4;
c_ptr1 += 4;
c_ptr2 += 4;
c_ptr3 += 4;
}
w = (width >> 2) << 2;
}
if (w < width) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 4, 4, width - w, stride_k, stride_w, c_ptr);
}
}
if (k < K) {
......@@ -411,7 +527,7 @@ inline void GemmTile(const float *A,
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
8,
4,
K - k,
width,
stride_k,
......@@ -420,7 +536,6 @@ inline void GemmTile(const float *A,
}
}
if (h < height) {
// TODO(liyin): may use Gemm444
const float *a_ptr = A + (h * stride_k);
const float *b_ptr = B;
float *c_ptr = C + h * stride_w;
......@@ -433,6 +548,11 @@ inline void GemmTile(const float *A,
stride_w,
c_ptr);
}
#else // cpu
GemmBlock(A, B, height, K, width, stride_k, stride_w, C);
#endif // armv7
#endif // aarch64
}
} // namespace
......
......@@ -7,7 +7,7 @@ package(
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled")
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7")
cc_library(
name = "test",
......@@ -34,7 +34,7 @@ cc_library(
["*.h"],
exclude = ["ops_test_util.h"],
),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]),
deps = [
"//mace/kernels",
],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册