diff --git a/mace/kernels/arm/conv_2d.cc b/mace/kernels/arm/conv_2d.cc
index e50cac08bd1f4d7d7693079d0f47f35962ff2e10..d9dd849e4f8d04535a466d0ced470c10afb8f016 100644
--- a/mace/kernels/arm/conv_2d.cc
+++ b/mace/kernels/arm/conv_2d.cc
@@ -162,7 +162,8 @@ void Conv2dFunctor::operator()(const Tensor *input,
   if (USE_WINOGRAD && filter_h == 3 && filter_w == 3
       && stride_h == 1 && stride_w == 1
-      && dilation_h == 1 && dilation_w == 1) {
+      && dilation_h == 1 && dilation_w == 1
+      && input_channels >= 8 && channels >= 8) {
     extra_output_height = RoundUp(height, 2);
     extra_input_height = std::max(padded_input_height, extra_output_height + 2);
     extra_output_width = RoundUp(width, 2);
diff --git a/mace/kernels/arm/conv_winograd.cc b/mace/kernels/arm/conv_winograd.cc
index c0509689f832d702ed23afdbd29c5b53abf42773..15c79a7d7bc6f376c3486f693bde26cc83794914 100644
--- a/mace/kernels/arm/conv_winograd.cc
+++ b/mace/kernels/arm/conv_winograd.cc
@@ -271,44 +271,6 @@ void TransformOutput(const float *input,
     }
   }
 }
-
-void ConvRef3x3s1(const float *input,
-                  const float *filter,
-                  const index_t batch,
-                  const index_t in_height,
-                  const index_t in_width,
-                  const index_t in_channels,
-                  const index_t out_channels,
-                  float *output) {
-  index_t out_height = in_height - 2;
-  index_t out_width = in_width - 2;
-
-#pragma omp parallel for collapse(4)
-  for (index_t b = 0; b < batch; ++b) {
-    for (index_t m = 0; m < out_channels; ++m) {
-      for (index_t h = 0; h < out_height; ++h) {
-        for (index_t w = 0; w < out_width; ++w) {
-          index_t out_offset =
-              ((b * out_channels + m) * out_height + h) * out_width + w;
-          output[out_offset] = 0;
-          for (index_t c = 0; c < in_channels; ++c) {
-            for (index_t kh = 0; kh < 3; ++kh) {
-              for (index_t kw = 0; kw < 3; ++kw) {
-                index_t ih = h + kh;
-                index_t iw = w + kw;
-                index_t in_offset =
-                    ((b * in_channels + c) * in_height + ih) * in_width + iw;
-                index_t
-                    filter_offset = (((m * in_channels) + c) * 3 + kh) * 3 + kw;
-                output[out_offset] += input[in_offset] * filter[filter_offset];
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-}
 }  // namespace
 
 void WinoGradConv3x3s1(const float *input,
@@ -400,5 +362,44 @@ void WinoGradConv3x3s1(const float *input,
   delete[]transformed_output;
 }
 
+void ConvRef3x3s1(const float *input,
+                  const float *filter,
+                  const index_t batch,
+                  const index_t in_height,
+                  const index_t in_width,
+                  const index_t in_channels,
+                  const index_t out_channels,
+                  float *output) {
+  index_t out_height = in_height - 2;
+  index_t out_width = in_width - 2;
+
+#pragma omp parallel for collapse(4)
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t m = 0; m < out_channels; ++m) {
+      for (index_t h = 0; h < out_height; ++h) {
+        for (index_t w = 0; w < out_width; ++w) {
+          index_t out_offset =
+              ((b * out_channels + m) * out_height + h) * out_width + w;
+          output[out_offset] = 0;
+          for (index_t c = 0; c < in_channels; ++c) {
+            for (index_t kh = 0; kh < 3; ++kh) {
+              for (index_t kw = 0; kw < 3; ++kw) {
+                index_t ih = h + kh;
+                index_t iw = w + kw;
+                index_t in_offset =
+                    ((b * in_channels + c) * in_height + ih) * in_width + iw;
+                index_t
+                    filter_offset = (((m * in_channels) + c) * 3 + kh) * 3 + kw;
+                output[out_offset] += input[in_offset] * filter[filter_offset];
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 }  // namespace kernels
 }  // namespace mace
diff --git a/mace/kernels/arm/conv_winograd.h b/mace/kernels/arm/conv_winograd.h
index 7611d65ae5e2a57b4542df40bc4d6bef3d04538d..0b288dd158a1eaace3ffa34ce1b3891ca8f90acc 100644
--- a/mace/kernels/arm/conv_winograd.h
+++ b/mace/kernels/arm/conv_winograd.h
@@ -36,6 +36,15 @@ void WinoGradConv3x3s1(const float *input,
                        bool is_filter_transformed,
                        float *output);
 
+void ConvRef3x3s1(const float *input,
+                  const float *filter,
+                  const index_t batch,
+                  const index_t in_height,
+                  const index_t in_width,
+                  const index_t in_channels,
+                  const index_t out_channels,
+                  float *output);
+
 }  // namespace kernels
 }  // namespace mace
 
diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc
index 00be4829802ede5fadbc0244917f56fcf0dd6025..b200c7650d596d868022d292029fa29fb1abcc18 100644
--- a/mace/kernels/gemm.cc
+++ b/mace/kernels/gemm.cc
@@ -13,22 +13,6 @@ namespace mace {
 namespace kernels {
 namespace {
 
-void GemmRef(const float *A,
-             const float *B,
-             const index_t height,
-             const index_t K,
-             const index_t width,
-             float *C) {
-  memset(C, 0, sizeof(float) * height * width);
-  for (int i = 0; i < height; ++i) {
-    for (int j = 0; j < width; ++j) {
-      for (int k = 0; k < K; ++k) {
-        C[i * width + j] += A[i * K + k] * B[k * width + j];
-      }
-    }
-  }
-}
-
 inline void GemmBlock(const float *A,
                       const float *B,
                       const index_t height,
@@ -49,8 +33,8 @@ inline void GemmBlock(const float *A,
 // TODO(liyin): may need implement 883 since RGB
 inline void Gemm884(const float *a_ptr,
                     const float *b_ptr,
-                    index_t stride_w,
                     index_t stride_k,
+                    index_t stride_w,
                     float *c_ptr) {
 #if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
   float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
@@ -136,29 +120,300 @@ inline void GemmTile(const float *A,
                      float *C) {
   index_t h, w, k;
   for (h = 0; h + 7 < height; h += 8) {
-    for (w = 0; w + 3 < width; w += 4) {
-      for (k = 0; k + 7 < K; k += 8) {
-        const float *a_ptr = A + (h * stride_k + k);
+    for (k = 0; k + 7 < K; k += 8) {
+      const float *a_ptr = A + (h * stride_k + k);
+
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+
+#ifdef __clang__
+      int nw = width >> 2;
+      if (nw > 0) {
+        // load A
+        float32x4_t a0, a1, a2, a3, a4, a5, a6, a7,
+            a8, a9, a10, a11, a12, a13, a14, a15;
+        a0 = vld1q_f32(a_ptr);
+        a1 = vld1q_f32(a_ptr + 4);
+        a2 = vld1q_f32(a_ptr + 1 * stride_k);
+        a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
+        a4 = vld1q_f32(a_ptr + 2 * stride_k);
+        a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
+        a6 = vld1q_f32(a_ptr + 3 * stride_k);
+        a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
+        a8 = vld1q_f32(a_ptr + 4 * stride_k);
+        a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
+        a10 = vld1q_f32(a_ptr + 5 * stride_k);
+        a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
+        a12 = vld1q_f32(a_ptr + 6 * stride_k);
+        a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
+        a14 = vld1q_f32(a_ptr + 7 * stride_k);
+        a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
+
+        const float *b_ptr0 = B + k * stride_w;
+        const float *b_ptr1 = B + (k + 1) * stride_w;
+        const float *b_ptr2 = B + (k + 2) * stride_w;
+        const float *b_ptr3 = B + (k + 3) * stride_w;
+        const float *b_ptr4 = B + (k + 4) * stride_w;
+        const float *b_ptr5 = B + (k + 5) * stride_w;
+        const float *b_ptr6 = B + (k + 6) * stride_w;
+        const float *b_ptr7 = B + (k + 7) * stride_w;
+
+        float *c_ptr0 = C + h * stride_w;
+        float *c_ptr1 = C + (h + 1) * stride_w;
+        float *c_ptr2 = C + (h + 2) * stride_w;
+        float *c_ptr3 = C + (h + 3) * stride_w;
+        float *c_ptr4 = C + (h + 4) * stride_w;
+        float *c_ptr5 = C + (h + 5) * stride_w;
+        float *c_ptr6 = C + (h + 6) * stride_w;
+        float *c_ptr7 = C + (h + 7) * stride_w;
+
+        asm volatile(
+            "0:                                 \n"
+
+            "prfm   pldl1keep, [%1, #128]       \n"
+            "ld1    {v24.4s}, [%1]              \n"
+
+            // load b: 0-7
+            "prfm   pldl1keep, [%9, #128]       \n"
+            "ld1    {v16.4s}, [%9], #16         \n"
+
+            "prfm   pldl1keep, [%10, #128]      \n"
+            "ld1    {v17.4s}, [%10], #16        \n"
+
+            "prfm   pldl1keep, [%11, #128]      \n"
+            "ld1    {v18.4s}, [%11], #16        \n"
+
+            "prfm   pldl1keep, [%12, #128]      \n"
+            "ld1    {v19.4s}, [%12], #16        \n"
+
+            "prfm   pldl1keep, [%2, #128]       \n"
+            "ld1    {v25.4s}, [%2]              \n"
+
+            "prfm   pldl1keep, [%13, #128]      \n"
+            "ld1    {v20.4s}, [%13], #16        \n"
+
+            "prfm   pldl1keep, [%14, #128]      \n"
+            "ld1    {v21.4s}, [%14], #16        \n"
+
+            "prfm   pldl1keep, [%15, #128]      \n"
+            "ld1    {v22.4s}, [%15], #16        \n"
+
+            "prfm   pldl1keep, [%16, #128]      \n"
+            "ld1    {v23.4s}, [%16], #16        \n"
+
+            "prfm   pldl1keep, [%3, #128]       \n"
+            "ld1    {v26.4s}, [%3]              \n"
+
+            "fmla   v24.4s, v16.4s, %34.s[0]    \n"
+            "fmla   v24.4s, v17.4s, %34.s[1]    \n"
+            "fmla   v24.4s, v18.4s, %34.s[2]    \n"
+            "fmla   v24.4s, v19.4s, %34.s[3]    \n"
+
+            "fmla   v24.4s, v20.4s, %35.s[0]    \n"
+            "fmla   v24.4s, v21.4s, %35.s[1]    \n"
+            "fmla   v24.4s, v22.4s, %35.s[2]    \n"
+            "fmla   v24.4s, v23.4s, %35.s[3]    \n"
+
+            "st1    {v24.4s}, [%1], #16         \n"
+
+            "fmla   v25.4s, v16.4s, %36.s[0]    \n"
+            "fmla   v25.4s, v17.4s, %36.s[1]    \n"
+            "fmla   v25.4s, v18.4s, %36.s[2]    \n"
+            "fmla   v25.4s, v19.4s, %36.s[3]    \n"
+
+            "fmla   v25.4s, v20.4s, %37.s[0]    \n"
+            "fmla   v25.4s, v21.4s, %37.s[1]    \n"
+            "fmla   v25.4s, v22.4s, %37.s[2]    \n"
+            "fmla   v25.4s, v23.4s, %37.s[3]    \n"
+
+            "prfm   pldl1keep, [%4, #128]       \n"
+            "ld1    {v24.4s}, [%4]              \n"
+
+            "st1    {v25.4s}, [%2], #16         \n"
+
+            "fmla   v26.4s, v16.4s, %38.s[0]    \n"
+            "fmla   v26.4s, v17.4s, %38.s[1]    \n"
+            "fmla   v26.4s, v18.4s, %38.s[2]    \n"
+            "fmla   v26.4s, v19.4s, %38.s[3]    \n"
+
+            "fmla   v26.4s, v20.4s, %39.s[0]    \n"
+            "fmla   v26.4s, v21.4s, %39.s[1]    \n"
+            "fmla   v26.4s, v22.4s, %39.s[2]    \n"
+            "fmla   v26.4s, v23.4s, %39.s[3]    \n"
+
+            "prfm   pldl1keep, [%5, #128]       \n"
+            "ld1    {v25.4s}, [%5]              \n"
+
+            "st1    {v26.4s}, [%3], #16         \n"
+
+            "fmla   v24.4s, v16.4s, %40.s[0]    \n"
+            "fmla   v24.4s, v17.4s, %40.s[1]    \n"
+            "fmla   v24.4s, v18.4s, %40.s[2]    \n"
+            "fmla   v24.4s, v19.4s, %40.s[3]    \n"
+
+            "fmla   v24.4s, v20.4s, %41.s[0]    \n"
+            "fmla   v24.4s, v21.4s, %41.s[1]    \n"
+            "fmla   v24.4s, v22.4s, %41.s[2]    \n"
+            "fmla   v24.4s, v23.4s, %41.s[3]    \n"
+
+            "prfm   pldl1keep, [%6, #128]       \n"
+            "ld1    {v26.4s}, [%6]              \n"
+
+            "st1    {v24.4s}, [%4], #16         \n"
+
+            "fmla   v25.4s, v16.4s, %42.s[0]    \n"
+            "fmla   v25.4s, v17.4s, %42.s[1]    \n"
+            "fmla   v25.4s, v18.4s, %42.s[2]    \n"
+            "fmla   v25.4s, v19.4s, %42.s[3]    \n"
+
+            "fmla   v25.4s, v20.4s, %43.s[0]    \n"
+            "fmla   v25.4s, v21.4s, %43.s[1]    \n"
+            "fmla   v25.4s, v22.4s, %43.s[2]    \n"
+            "fmla   v25.4s, v23.4s, %43.s[3]    \n"
+
+            "prfm   pldl1keep, [%7, #128]       \n"
+            "ld1    {v24.4s}, [%7]              \n"
+
+            "st1    {v25.4s}, [%5], #16         \n"
+
+            "fmla   v26.4s, v16.4s, %44.s[0]    \n"
+            "fmla   v26.4s, v17.4s, %44.s[1]    \n"
+            "fmla   v26.4s, v18.4s, %44.s[2]    \n"
+            "fmla   v26.4s, v19.4s, %44.s[3]    \n"
+
+            "fmla   v26.4s, v20.4s, %45.s[0]    \n"
+            "fmla   v26.4s, v21.4s, %45.s[1]    \n"
+            "fmla   v26.4s, v22.4s, %45.s[2]    \n"
+            "fmla   v26.4s, v23.4s, %45.s[3]    \n"
+
+            "prfm   pldl1keep, [%8, #128]       \n"
+            "ld1    {v25.4s}, [%8]              \n"
+
+            "st1    {v26.4s}, [%6], #16         \n"
+
+            "fmla   v24.4s, v16.4s, %46.s[0]    \n"
+            "fmla   v24.4s, v17.4s, %46.s[1]    \n"
+            "fmla   v24.4s, v18.4s, %46.s[2]    \n"
+            "fmla   v24.4s, v19.4s, %46.s[3]    \n"
+
+            "fmla   v24.4s, v20.4s, %47.s[0]    \n"
+            "fmla   v24.4s, v21.4s, %47.s[1]    \n"
+            "fmla   v24.4s, v22.4s, %47.s[2]    \n"
+            "fmla   v24.4s, v23.4s, %47.s[3]    \n"
+
+            "st1    {v24.4s}, [%7], #16         \n"
+
+            "fmla   v25.4s, v16.4s, %48.s[0]    \n"
+            "fmla   v25.4s, v17.4s, %48.s[1]    \n"
+            "fmla   v25.4s, v18.4s, %48.s[2]    \n"
+            "fmla   v25.4s, v19.4s, %48.s[3]    \n"
+
+            "fmla   v25.4s, v20.4s, %49.s[0]    \n"
+            "fmla   v25.4s, v21.4s, %49.s[1]    \n"
+            "fmla   v25.4s, v22.4s, %49.s[2]    \n"
+            "fmla   v25.4s, v23.4s, %49.s[3]    \n"
+
+            "st1    {v25.4s}, [%8], #16         \n"
+
+            "subs   %w0, %w0, #1                \n"
+            "bne    0b                          \n"
+            : "=r"(nw),      // 0
+              "=r"(c_ptr0),  // 1
+              "=r"(c_ptr1),  // 2
+              "=r"(c_ptr2),  // 3
+              "=r"(c_ptr3),  // 4
+              "=r"(c_ptr4),  // 5
+              "=r"(c_ptr5),  // 6
+              "=r"(c_ptr6),  // 7
+              "=r"(c_ptr7),  // 8
+              "=r"(b_ptr0),  // 9
+              "=r"(b_ptr1),  // 10
+              "=r"(b_ptr2),  // 11
+              "=r"(b_ptr3),  // 12
+              "=r"(b_ptr4),  // 13
+              "=r"(b_ptr5),  // 14
+              "=r"(b_ptr6),  // 15
+              "=r"(b_ptr7)   // 16
+            : "0"(nw),       // 17
+              "1"(c_ptr0),   // 18
+              "2"(c_ptr1),   // 19
+              "3"(c_ptr2),   // 20
+              "4"(c_ptr3),   // 21
+              "5"(c_ptr4),   // 22
+              "6"(c_ptr5),   // 23
+              "7"(c_ptr6),   // 24
+              "8"(c_ptr7),   // 25
+              "9"(b_ptr0),   // 26
+              "10"(b_ptr1),  // 27
+              "11"(b_ptr2),  // 28
+              "12"(b_ptr3),  // 29
+              "13"(b_ptr4),  // 30
+              "14"(b_ptr5),  // 31
+              "15"(b_ptr6),  // 32
+              "16"(b_ptr7),  // 33
+              "w"(a0),       // 34
+              "w"(a1),       // 35
+              "w"(a2),       // 36
+              "w"(a3),       // 37
+              "w"(a4),       // 38
+              "w"(a5),       // 39
+              "w"(a6),       // 40
+              "w"(a7),       // 41
+              "w"(a8),       // 42
+              "w"(a9),       // 43
+              "w"(a10),      // 44
+              "w"(a11),      // 45
+              "w"(a12),      // 46
+              "w"(a13),      // 47
+              "w"(a14),      // 48
+              "w"(a15)       // 49
+            : "cc", "memory",
+              "v16",
+              "v17",
+              "v18",
+              "v19",
+              "v20",
+              "v21",
+              "v22",
+              "v23",
+              "v24",
+              "v25",
+              "v26"
+        );
+
+        w = (width >> 2) << 2;
+      }
+#else  // gcc
+      for (w = 0; w + 3 < width; w += 4) {
+        const float *b_ptr = B + (k * stride_w + w);
+        float *c_ptr = C + (h * stride_w + w);
+        Gemm884(a_ptr, b_ptr, stride_k, stride_w, c_ptr);
+      }
+#endif
+
+#else
+      for (w = 0; w + 3 < width; w += 4) {
         const float *b_ptr = B + (k * stride_w + w);
         float *c_ptr = C + (h * stride_w + w);
-        Gemm884(a_ptr, b_ptr, stride_w, stride_k, c_ptr);
+        GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
       }
-      if (k < K) {
+#endif
+
+      if (w < width) {
         const float *a_ptr = A + (h * stride_k + k);
         const float *b_ptr = B + (k * stride_w + w);
         float *c_ptr = C + (h * stride_w + w);
-        GemmBlock(a_ptr, b_ptr, 8, K - k, 4, stride_k, stride_w, c_ptr);
+        GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_k, stride_w, c_ptr);
       }
     }
-    if (w < width) {
-      const float *a_ptr = A + h * stride_k;
-      const float *b_ptr = B + w;
-      float *c_ptr = C + (h * stride_w + w);
+    if (k < K) {
+      const float *a_ptr = A + (h * stride_k + k);
+      const float *b_ptr = B + k * stride_w;
+      float *c_ptr = C + h * stride_w;
       GemmBlock(a_ptr, b_ptr,
                 8,
-                K,
-                width - w,
+                K - k,
+                width,
                 stride_k,
                 stride_w,
                 c_ptr);
@@ -243,5 +498,21 @@ void Gemm(const float *A,
   }  // n
 }
 
+void GemmRef(const float *A,
+             const float *B,
+             const index_t height,
+             const index_t K,
+             const index_t width,
+             float *C) {
+  memset(C, 0, sizeof(float) * height * width);
+  for (int i = 0; i < height; ++i) {
+    for (int j = 0; j < width; ++j) {
+      for (int k = 0; k < K; ++k) {
+        C[i * width + j] += A[i * K + k] * B[k * width + j];
+      }
+    }
+  }
+}
+
 }  // namespace kernels
 }  // namespace mace
diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h
index d17eab83e12d5531eed8bdaddd6af352b179b327..eec69e8ef7675a28a3fa79a8f5bf5f8e957d6865 100644
--- a/mace/kernels/gemm.h
+++ b/mace/kernels/gemm.h
@@ -22,6 +22,13 @@ void Gemm(const float *A,
           const index_t width,
           float *C);
 
+void GemmRef(const float *A,
+             const float *B,
+             const index_t height,
+             const index_t K,
+             const index_t width,
+             float *C);
+
 }  // namespace kernels
 }  // namespace mace
 
diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc
index 9e4b964daf5922a6fd0912ca65d9b01225aaf2fb..dacb93eebcab3aae16035477d4a591780f6b32a6 100644
--- a/mace/kernels/gemm_test.cc
+++ b/mace/kernels/gemm_test.cc
@@ -31,7 +31,7 @@ TEST(GEMMTest, gemm) {
                 [&gen, &nd] { return nd(gen); });
 
-  kernels::Gemm(A, B, N, K, M, C);
+  kernels::Gemm(A, B, 1, N, K, M, C);
   kernels::GemmRef(A, B, N, K, M, C_ref);
 
   for (int i = 0; i < N * M; ++i) {
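
Reviewer note on the GemmTile change: the old loop nest iterated w outer / k inner, so the 8x8 panel of A was re-read for every 4-wide tile of C. The patch hoists the k loop outward, loads the A panel once per row of tiles, and (under clang on aarch64) pins it in v-registers as the "w"-constrained operands %34-%49 of the inline asm, whose `0b` loop then streams B and C. Below is a minimal scalar sketch of that traversal for intuition only; `GemmTileSketch` and its simplified assumptions (row-major, no batch, row stride equal to row width, `height % 8 == 0`, `K % 8 == 0`, `width % 4 == 0`, C zeroed inside) are hypothetical and not part of the patch.

```cpp
#include <cstddef>
#include <cstring>

using index_t = std::ptrdiff_t;  // stand-in for mace::index_t

// Scalar model of the new traversal: for each 8-row band of C, walk k in
// panels of 8, load the 8x8 A panel once, then sweep every 4-column tile
// of B/C with it. Each tile accumulates A[h..h+7][k..k+7] * B[k..k+7][w..w+3].
void GemmTileSketch(const float *A, const float *B,
                    index_t height, index_t K, index_t width, float *C) {
  std::memset(C, 0, sizeof(float) * height * width);
  for (index_t h = 0; h < height; h += 8) {
    for (index_t k = 0; k < K; k += 8) {
      const float *a_panel = A + h * K + k;   // reused across the whole w sweep
      for (index_t w = 0; w < width; w += 4) {
        for (index_t i = 0; i < 8; ++i) {     // rows of the micro-tile
          for (index_t j = 0; j < 4; ++j) {   // cols of the micro-tile
            float acc = C[(h + i) * width + w + j];
            for (index_t kk = 0; kk < 8; ++kk) {
              acc += a_panel[i * K + kk] * B[(k + kk) * width + w + j];
            }
            C[(h + i) * width + w + j] = acc;
          }
        }
      }
    }
  }
}
```

With this order each A panel is loaded once and reused `width / 4` times, which is what lets the asm version hold a0-a15 in registers for the entire inner loop instead of reloading them per tile; the gcc and non-NEON branches fall back to Gemm884/GemmBlock with the same k-outer ordering.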