提交 db75d542 编写于 作者: 吴承辉

Merge branch 'gemm-asm' into 'master'

Implement ASM GEMM

See merge request !354
......@@ -162,7 +162,8 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
if (USE_WINOGRAD && filter_h == 3 && filter_w == 3 && stride_h == 1
&& stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
&& dilation_h == 1 && dilation_w == 1
&& input_channels >= 8 && channels >= 8) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 2);
......
......@@ -271,44 +271,6 @@ void TransformOutput(const float *input,
}
}
}
// Naive reference implementation of a 3x3, stride-1 convolution without
// padding. It is meant as a ground truth for optimized kernels.
//
// Buffers are indexed as dense row-major arrays:
//   input  : [batch][in_channels][in_height][in_width]
//   filter : [out_channels][in_channels][3][3]
//   output : [batch][out_channels][in_height - 2][in_width - 2]
// Every output element is reset to zero and then accumulated in place,
// visiting input channels first and the nine filter taps in row-major order.
void ConvRef3x3s1(const float *input,
                  const float *filter,
                  const index_t batch,
                  const index_t in_height,
                  const index_t in_width,
                  const index_t in_channels,
                  const index_t out_channels,
                  float *output) {
  const index_t out_height = in_height - 2;
  const index_t out_width = in_width - 2;

  // collapse(4) needs the four loops below to stay perfectly nested, so all
  // per-element address computation lives inside the innermost one.
#pragma omp parallel for collapse(4)
  for (index_t n = 0; n < batch; ++n) {
    for (index_t oc = 0; oc < out_channels; ++oc) {
      for (index_t oy = 0; oy < out_height; ++oy) {
        for (index_t ox = 0; ox < out_width; ++ox) {
          float *dst = output +
              (((n * out_channels + oc) * out_height + oy) * out_width + ox);
          *dst = 0;
          for (index_t ic = 0; ic < in_channels; ++ic) {
            // Start of the (n, ic) input plane and of the (oc, ic) 3x3 kernel.
            const float *plane =
                input + (n * in_channels + ic) * in_height * in_width;
            const float *taps = filter + (oc * in_channels + ic) * 9;
            for (index_t t = 0; t < 9; ++t) {
              const index_t iy = oy + t / 3;
              const index_t ix = ox + t % 3;
              *dst += plane[iy * in_width + ix] * taps[t];
            }
          }
        }
      }
    }
  }
}
} // namespace
void WinoGradConv3x3s1(const float *input,
......@@ -400,5 +362,44 @@ void WinoGradConv3x3s1(const float *input,
delete[]transformed_output;
}
// Naive reference implementation of a 3x3, stride-1 convolution without
// padding. It is meant as a ground truth for optimized kernels.
//
// Buffers are indexed as dense row-major arrays:
//   input  : [batch][in_channels][in_height][in_width]
//   filter : [out_channels][in_channels][3][3]
//   output : [batch][out_channels][in_height - 2][in_width - 2]
// Every output element is reset to zero and then accumulated in place,
// visiting input channels first and the nine filter taps in row-major order.
void ConvRef3x3s1(const float *input,
                  const float *filter,
                  const index_t batch,
                  const index_t in_height,
                  const index_t in_width,
                  const index_t in_channels,
                  const index_t out_channels,
                  float *output) {
  const index_t out_height = in_height - 2;
  const index_t out_width = in_width - 2;

  // collapse(4) needs the four loops below to stay perfectly nested, so all
  // per-element address computation lives inside the innermost one.
#pragma omp parallel for collapse(4)
  for (index_t n = 0; n < batch; ++n) {
    for (index_t oc = 0; oc < out_channels; ++oc) {
      for (index_t oy = 0; oy < out_height; ++oy) {
        for (index_t ox = 0; ox < out_width; ++ox) {
          float *dst = output +
              (((n * out_channels + oc) * out_height + oy) * out_width + ox);
          *dst = 0;
          for (index_t ic = 0; ic < in_channels; ++ic) {
            // Start of the (n, ic) input plane and of the (oc, ic) 3x3 kernel.
            const float *plane =
                input + (n * in_channels + ic) * in_height * in_width;
            const float *taps = filter + (oc * in_channels + ic) * 9;
            for (index_t t = 0; t < 9; ++t) {
              const index_t iy = oy + t / 3;
              const index_t ix = ox + t % 3;
              *dst += plane[iy * in_width + ix] * taps[t];
            }
          }
        }
      }
    }
  }
}
} // namespace kernels
} // namespace mace
......@@ -36,6 +36,15 @@ void WinoGradConv3x3s1(const float *input,
bool is_filter_transformed,
float *output);
// Naive reference 3x3, stride-1 convolution without padding.
//
// input  is indexed as [batch][in_channels][in_height][in_width].
// filter is indexed as [out_channels][in_channels][3][3].
// output is indexed as [batch][out_channels][in_height - 2][in_width - 2]
//        and is fully overwritten by the call.
void ConvRef3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *output);
} // namespace kernels
} // namespace mace
......
......@@ -13,22 +13,6 @@ namespace mace {
namespace kernels {
namespace {
// Naive reference matrix multiply: C = A * B, single precision.
//
// All matrices are dense and row-major:
//   A : height x K
//   B : K x width
//   C : height x width, fully overwritten by the call
// With K == 0 every element of C is set to zero. Non-positive height or
// width leaves C untouched.
void GemmRef(const float *A,
             const float *B,
             const index_t height,
             const index_t K,
             const index_t width,
             float *C) {
  // Loop counters are index_t, matching the bounds: an int counter compared
  // against a 64-bit bound overflows once a dimension exceeds INT_MAX.
  for (index_t i = 0; i < height; ++i) {
    for (index_t j = 0; j < width; ++j) {
      // Accumulate locally and store once. This replaces the up-front memset,
      // whose byte count wrapped around for negative dimensions.
      float acc = 0.0f;
      for (index_t k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * width + j];
      }
      C[i * width + j] = acc;
    }
  }
}
inline void GemmBlock(const float *A,
const float *B,
const index_t height,
......@@ -49,8 +33,8 @@ inline void GemmBlock(const float *A,
// TODO(liyin): may need to implement 883 too, since RGB data has 3 channels
inline void Gemm884(const float *a_ptr,
const float *b_ptr,
index_t stride_w,
index_t stride_k,
index_t stride_w,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
......@@ -136,29 +120,300 @@ inline void GemmTile(const float *A,
float *C) {
index_t h, w, k;
for (h = 0; h + 7 < height; h += 8) {
for (w = 0; w + 3 < width; w += 4) {
for (k = 0; k + 7 < K; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
for (k = 0; k + 7 < K; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#ifdef __clang__
int nw = width >> 2;
if (nw > 0) {
// load A
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7,
a8, a9, a10, a11, a12, a13, a14, a15;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_k);
a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_k);
a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
const float *b_ptr0 = B + k * stride_w;
const float *b_ptr1 = B + (k + 1) * stride_w;
const float *b_ptr2 = B + (k + 2) * stride_w;
const float *b_ptr3 = B + (k + 3) * stride_w;
const float *b_ptr4 = B + (k + 4) * stride_w;
const float *b_ptr5 = B + (k + 5) * stride_w;
const float *b_ptr6 = B + (k + 6) * stride_w;
const float *b_ptr7 = B + (k + 7) * stride_w;
float *c_ptr0 = C + h * stride_w;
float *c_ptr1 = C + (h + 1) * stride_w;
float *c_ptr2 = C + (h + 2) * stride_w;
float *c_ptr3 = C + (h + 3) * stride_w;
float *c_ptr4 = C + (h + 4) * stride_w;
float *c_ptr5 = C + (h + 5) * stride_w;
float *c_ptr6 = C + (h + 6) * stride_w;
float *c_ptr7 = C + (h + 7) * stride_w;
asm volatile(
"0: \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v24.4s}, [%1] \n"
// load b: 0-7
"prfm pldl1keep, [%9, #128] \n"
"ld1 {v16.4s}, [%9], #16 \n"
"prfm pldl1keep, [%10, #128] \n"
"ld1 {v17.4s}, [%10], #16 \n"
"prfm pldl1keep, [%11, #128] \n"
"ld1 {v18.4s}, [%11], #16 \n"
"prfm pldl1keep, [%12, #128] \n"
"ld1 {v19.4s}, [%12], #16 \n"
"prfm pldl1keep, [%2, #128] \n"
"ld1 {v25.4s}, [%2] \n"
"prfm pldl1keep, [%13, #128] \n"
"ld1 {v20.4s}, [%13], #16 \n"
"prfm pldl1keep, [%14, #128] \n"
"ld1 {v21.4s}, [%14], #16 \n"
"prfm pldl1keep, [%15, #128] \n"
"ld1 {v22.4s}, [%15], #16 \n"
"prfm pldl1keep, [%16, #128] \n"
"ld1 {v23.4s}, [%16], #16 \n"
"prfm pldl1keep, [%3, #128] \n"
"ld1 {v26.4s}, [%3] \n"
"fmla v24.4s, v16.4s, %34.s[0] \n"
"fmla v24.4s, v17.4s, %34.s[1] \n"
"fmla v24.4s, v18.4s, %34.s[2] \n"
"fmla v24.4s, v19.4s, %34.s[3] \n"
"fmla v24.4s, v20.4s, %35.s[0] \n"
"fmla v24.4s, v21.4s, %35.s[1] \n"
"fmla v24.4s, v22.4s, %35.s[2] \n"
"fmla v24.4s, v23.4s, %35.s[3] \n"
"st1 {v24.4s}, [%1], #16 \n"
"fmla v25.4s, v16.4s, %36.s[0] \n"
"fmla v25.4s, v17.4s, %36.s[1] \n"
"fmla v25.4s, v18.4s, %36.s[2] \n"
"fmla v25.4s, v19.4s, %36.s[3] \n"
"fmla v25.4s, v20.4s, %37.s[0] \n"
"fmla v25.4s, v21.4s, %37.s[1] \n"
"fmla v25.4s, v22.4s, %37.s[2] \n"
"fmla v25.4s, v23.4s, %37.s[3] \n"
"prfm pldl1keep, [%4, #128] \n"
"ld1 {v24.4s}, [%4] \n"
"st1 {v25.4s}, [%2], #16 \n"
"fmla v26.4s, v16.4s, %38.s[0] \n"
"fmla v26.4s, v17.4s, %38.s[1] \n"
"fmla v26.4s, v18.4s, %38.s[2] \n"
"fmla v26.4s, v19.4s, %38.s[3] \n"
"fmla v26.4s, v20.4s, %39.s[0] \n"
"fmla v26.4s, v21.4s, %39.s[1] \n"
"fmla v26.4s, v22.4s, %39.s[2] \n"
"fmla v26.4s, v23.4s, %39.s[3] \n"
"prfm pldl1keep, [%5, #128] \n"
"ld1 {v25.4s}, [%5] \n"
"st1 {v26.4s}, [%3], #16 \n"
"fmla v24.4s, v16.4s, %40.s[0] \n"
"fmla v24.4s, v17.4s, %40.s[1] \n"
"fmla v24.4s, v18.4s, %40.s[2] \n"
"fmla v24.4s, v19.4s, %40.s[3] \n"
"fmla v24.4s, v20.4s, %41.s[0] \n"
"fmla v24.4s, v21.4s, %41.s[1] \n"
"fmla v24.4s, v22.4s, %41.s[2] \n"
"fmla v24.4s, v23.4s, %41.s[3] \n"
"prfm pldl1keep, [%6, #128] \n"
"ld1 {v26.4s}, [%6] \n"
"st1 {v24.4s}, [%4], #16 \n"
"fmla v25.4s, v16.4s, %42.s[0] \n"
"fmla v25.4s, v17.4s, %42.s[1] \n"
"fmla v25.4s, v18.4s, %42.s[2] \n"
"fmla v25.4s, v19.4s, %42.s[3] \n"
"fmla v25.4s, v20.4s, %43.s[0] \n"
"fmla v25.4s, v21.4s, %43.s[1] \n"
"fmla v25.4s, v22.4s, %43.s[2] \n"
"fmla v25.4s, v23.4s, %43.s[3] \n"
"prfm pldl1keep, [%7, #128] \n"
"ld1 {v24.4s}, [%7] \n"
"st1 {v25.4s}, [%5], #16 \n"
"fmla v26.4s, v16.4s, %44.s[0] \n"
"fmla v26.4s, v17.4s, %44.s[1] \n"
"fmla v26.4s, v18.4s, %44.s[2] \n"
"fmla v26.4s, v19.4s, %44.s[3] \n"
"fmla v26.4s, v20.4s, %45.s[0] \n"
"fmla v26.4s, v21.4s, %45.s[1] \n"
"fmla v26.4s, v22.4s, %45.s[2] \n"
"fmla v26.4s, v23.4s, %45.s[3] \n"
"prfm pldl1keep, [%8, #128] \n"
"ld1 {v25.4s}, [%8] \n"
"st1 {v26.4s}, [%6], #16 \n"
"fmla v24.4s, v16.4s, %46.s[0] \n"
"fmla v24.4s, v17.4s, %46.s[1] \n"
"fmla v24.4s, v18.4s, %46.s[2] \n"
"fmla v24.4s, v19.4s, %46.s[3] \n"
"fmla v24.4s, v20.4s, %47.s[0] \n"
"fmla v24.4s, v21.4s, %47.s[1] \n"
"fmla v24.4s, v22.4s, %47.s[2] \n"
"fmla v24.4s, v23.4s, %47.s[3] \n"
"st1 {v24.4s}, [%7], #16 \n"
"fmla v25.4s, v16.4s, %48.s[0] \n"
"fmla v25.4s, v17.4s, %48.s[1] \n"
"fmla v25.4s, v18.4s, %48.s[2] \n"
"fmla v25.4s, v19.4s, %48.s[3] \n"
"fmla v25.4s, v20.4s, %49.s[0] \n"
"fmla v25.4s, v21.4s, %49.s[1] \n"
"fmla v25.4s, v22.4s, %49.s[2] \n"
"fmla v25.4s, v23.4s, %49.s[3] \n"
"st1 {v25.4s}, [%8], #16 \n"
"subs %w0, %w0, #1 \n"
"bne 0b \n"
: "=r"(nw), // 0
"=r"(c_ptr0), // 1
"=r"(c_ptr1), // 2
"=r"(c_ptr2), // 3
"=r"(c_ptr3), // 4
"=r"(c_ptr4), // 5
"=r"(c_ptr5), // 6
"=r"(c_ptr6), // 7
"=r"(c_ptr7), // 8
"=r"(b_ptr0), // 9
"=r"(b_ptr1), // 10
"=r"(b_ptr2), // 11
"=r"(b_ptr3), // 12
"=r"(b_ptr4), // 13
"=r"(b_ptr5), // 14
"=r"(b_ptr6), // 15
"=r"(b_ptr7) // 16
: "0"(nw), // 17
"1"(c_ptr0), // 18
"2"(c_ptr1), // 19
"3"(c_ptr2), // 20
"4"(c_ptr3), // 21
"5"(c_ptr4), // 22
"6"(c_ptr5), // 23
"7"(c_ptr6), // 24
"8"(c_ptr7), // 25
"9"(b_ptr0), // 26
"10"(b_ptr1), // 27
"11"(b_ptr2), // 28
"12"(b_ptr3), // 29
"13"(b_ptr4), // 30
"14"(b_ptr5), // 31
"15"(b_ptr6), // 32
"16"(b_ptr7), // 33
"w"(a0), // 34
"w"(a1), // 35
"w"(a2), // 36
"w"(a3), // 37
"w"(a4), // 38
"w"(a5), // 39
"w"(a6), // 40
"w"(a7), // 41
"w"(a8), // 42
"w"(a9), // 43
"w"(a10), // 44
"w"(a11), // 45
"w"(a12), // 46
"w"(a13), // 47
"w"(a14), // 48
"w"(a15) // 49
: "cc", "memory",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"v26"
);
w = (width >> 2) << 2;
}
#else // gcc
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
}
#endif
#else
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_w, stride_k, c_ptr);
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
}
if (k < K) {
#endif
if (w < width) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, K - k, 4, stride_k, stride_w, c_ptr);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
}
}
if (w < width) {
const float *a_ptr = A + h * stride_k;
const float *b_ptr = B + w;
float *c_ptr = C + (h * stride_w + w);
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + k * stride_w;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
8,
K,
width - w,
K - k,
width,
stride_k,
stride_w,
c_ptr);
......@@ -243,5 +498,21 @@ void Gemm(const float *A,
} // n
}
// Naive reference matrix multiply: C = A * B, single precision.
//
// All matrices are dense and row-major:
//   A : height x K
//   B : K x width
//   C : height x width, fully overwritten by the call
// With K == 0 every element of C is set to zero. Non-positive height or
// width leaves C untouched.
void GemmRef(const float *A,
             const float *B,
             const index_t height,
             const index_t K,
             const index_t width,
             float *C) {
  // Loop counters are index_t, matching the bounds: an int counter compared
  // against a 64-bit bound overflows once a dimension exceeds INT_MAX.
  for (index_t i = 0; i < height; ++i) {
    for (index_t j = 0; j < width; ++j) {
      // Accumulate locally and store once. This replaces the up-front memset,
      // whose byte count wrapped around for negative dimensions.
      float acc = 0.0f;
      for (index_t k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * width + j];
      }
      C[i * width + j] = acc;
    }
  }
}
} // namespace kernels
} // namespace mace
......@@ -22,6 +22,13 @@ void Gemm(const float *A,
const index_t width,
float *C);
// Naive reference matrix multiply, C = A * B, on dense row-major buffers:
// A is height x K, B is K x width, and C is height x width.
// C is written by the call; its previous contents are not read.
void GemmRef(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
float *C);
} // namespace kernels
} // namespace mace
......
......@@ -31,7 +31,7 @@ TEST(GEMMTest, gemm) {
[&gen, &nd] {
return nd(gen);
});
kernels::Gemm(A, B, N, K, M, C);
kernels::Gemm(A, B, 1, N, K, M, C);
kernels::GemmRef(A, B, N, K, M, C_ref);
for (int i = 0; i < N * M; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册