提交 7ca3d579 编写于 作者: M Megvii Engine Team

feat(dnn): make mk4 and mk8 matmul for winograd both on aarch64 and armv7 supports n=1

GitOrigin-RevId: 0f64b9f70f5f010696cec06ba37113c0d7dd178f
上级 54d18115
......@@ -214,8 +214,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x16::usable(
kern_size_param.B_type == dtype::Float32() &&
kern_size_param.A_type == dtype::Float32() &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
......@@ -330,8 +329,7 @@ bool MatrixMulImpl::AlgoF16MK8_8x8::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16() &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
......@@ -918,8 +916,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable(
kern_size_param.B_type == dtype::Int16() &&
kern_size_param.A_type == dtype::Int16() &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
......
......@@ -21,6 +21,76 @@ using namespace aarch64::matmul;
namespace {
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
LDB *= sizeof(dt_float16);
asm volatile(
".arch armv8.2-a+fp16\n"
"subs %w[K], %w[K], #8\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n"
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n"
"eor v24.16b, v24.16b, v24.16b\n"
"eor v25.16b, v25.16b, v25.16b\n"
"eor v26.16b, v26.16b, v26.16b\n"
"eor v27.16b, v27.16b, v27.16b\n"
"eor v28.16b, v28.16b, v28.16b\n"
"eor v29.16b, v29.16b, v29.16b\n"
"eor v30.16b, v30.16b, v30.16b\n"
"eor v31.16b, v31.16b, v31.16b\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v24.8h, v16.8h, v0.h[0]\n"
"fmla v25.8h, v17.8h, v0.h[1]\n"
"fmla v26.8h, v18.8h, v0.h[2]\n"
"fmla v27.8h, v19.8h, v0.h[3]\n"
"beq 2f\n"
"1:\n"
"ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n"
"fmla v28.8h, v20.8h, v0.h[4]\n"
"fmla v29.8h, v21.8h, v0.h[5]\n"
"fmla v30.8h, v22.8h, v0.h[6]\n"
"fmla v31.8h, v23.8h, v0.h[7]\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n"
"fmla v24.8h, v16.8h, v0.h[0]\n"
"fmla v25.8h, v17.8h, v0.h[1]\n"
"fmla v26.8h, v18.8h, v0.h[2]\n"
"fmla v27.8h, v19.8h, v0.h[3]\n"
"subs %w[K], %w[K], #8\n"
"bne 1b\n"
"2:\n"
"fmla v28.8h, v20.8h, v0.h[4]\n"
"fmla v29.8h, v21.8h, v0.h[5]\n"
"fmla v30.8h, v22.8h, v0.h[6]\n"
"fmla v31.8h, v23.8h, v0.h[7]\n"
"fadd v24.8h, v24.8h, v25.8h\n"
"fadd v26.8h, v26.8h, v27.8h\n"
"fadd v28.8h, v28.8h, v29.8h\n"
"fadd v30.8h, v30.8h, v31.8h\n"
"fadd v24.8h, v24.8h, v26.8h\n"
"fadd v28.8h, v28.8h, v30.8h\n"
"fadd v24.8h, v24.8h, v28.8h\n"
"st1 {v24.4s}, [%[output]], 16\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
}
// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v0-v3
......@@ -416,7 +486,7 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
......@@ -428,8 +498,17 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
if (N - n >= 4) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK;
output += MB * CALCBLK;
n += 4;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
......
......@@ -20,6 +20,54 @@ using namespace aarch64::matmul;
namespace {
void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
LDB *= sizeof(float);
asm volatile(
"subs %w[K], %w[K], #4\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n"
"eor v16.16b, v16.16b, v16.16b\n"
"eor v17.16b, v17.16b, v17.16b\n"
"eor v18.16b, v18.16b, v18.16b\n"
"eor v19.16b, v19.16b, v19.16b\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"prfm pstl1keep, [%[b_ptr]]\n"
"fmla v16.4s, v4.4s, v0.s[0]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"beq 2f\n"
"1:\n"
"ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n"
"fmla v18.4s, v6.4s, v0.s[2]\n"
"fmla v19.4s, v7.4s, v0.s[3]\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"prfm pstl1keep, [%[b_ptr]]\n"
"ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n"
"fmla v16.4s, v4.4s, v0.s[0]\n"
"fmla v17.4s, v5.4s, v0.s[1]\n"
"subs %w[K], %w[K], #4\n"
"bne 1b\n"
"2:\n"
"fmla v18.4s, v6.4s, v0.s[2]\n"
"fmla v19.4s, v7.4s, v0.s[3]\n"
"fadd v16.4s, v16.4s, v18.4s\n"
"fadd v17.4s, v17.4s, v19.4s\n"
"fadd v16.4s, v16.4s, v17.4s\n"
"st1 {v16.4s}, [%[output]], 16\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc",
"memory");
}
// Overview of register layout:
//
// A 4x4 block of A is stored in register v4-v7
......@@ -117,7 +165,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "cc", "memory");
}
// Overview of register layout:
......@@ -535,7 +584,7 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
constexpr static size_t NB = 16;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4)
for (size_t m = 0; m < M; m += MB) {
......@@ -547,21 +596,23 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
cur_B += KB * NB;
output += MB * NB;
}
switch (N - n) {
case 4:
kern_4x4(A, cur_B, LDB, K, output);
break;
case 8:
kern_4x8(A, cur_B, LDB, K, output);
break;
case 12:
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK * 2;
output += MB * CALCBLK * 2;
kern_4x4(A, cur_B, LDB, K, output);
break;
default:
break;
if (N - n >= 8) {
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK * 2;
output += MB * CALCBLK * 2;
n += 8;
}
if (N - n >= 4) {
kern_4x4(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK;
output += MB * CALCBLK;
n += 4;
}
while (n < N) {
kern_4x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
......
......@@ -20,6 +20,82 @@ using namespace aarch64::matmul;
namespace {
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24
//! here.
LDB *= sizeof(dt_int16);
asm volatile(
"subs %w[K], %w[K], #8\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"smull v16.4s, v24.4h, v0.h[0]\n"
"smull2 v17.4s, v24.8h, v0.h[0]\n"
"smull v18.4s, v25.4h, v0.h[1]\n"
"smull2 v19.4s, v25.8h, v0.h[1]\n"
"smull v20.4s, v26.4h, v0.h[2]\n"
"smull2 v21.4s, v26.8h, v0.h[2]\n"
"smull v22.4s, v27.4h, v0.h[3]\n"
"smull2 v23.4s, v27.8h, v0.h[3]\n"
"beq 2f\n"
"1:\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n"
"smlal v16.4s, v28.4h, v0.h[4]\n"
"smlal2 v17.4s, v28.8h, v0.h[4]\n"
"smlal v18.4s, v29.4h, v0.h[5]\n"
"smlal2 v19.4s, v29.8h, v0.h[5]\n"
"smlal v20.4s, v30.4h, v0.h[6]\n"
"smlal2 v21.4s, v30.8h, v0.h[6]\n"
"smlal v22.4s, v31.4h, v0.h[7]\n"
"smlal2 v23.4s, v31.8h, v0.h[7]\n"
"ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n"
"ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n"
"smlal v16.4s, v24.4h, v0.h[0]\n"
"smlal2 v17.4s, v24.8h, v0.h[0]\n"
"smlal v18.4s, v25.4h, v0.h[1]\n"
"smlal2 v19.4s, v25.8h, v0.h[1]\n"
"smlal v20.4s, v26.4h, v0.h[2]\n"
"smlal2 v21.4s, v26.8h, v0.h[2]\n"
"smlal v22.4s, v27.4h, v0.h[3]\n"
"smlal2 v23.4s, v27.8h, v0.h[3]\n"
"subs %w[K], %w[K], #8\n"
"bne 1b\n"
"2:\n"
"smlal v16.4s, v28.4h, v0.h[4]\n"
"smlal2 v17.4s, v28.8h, v0.h[4]\n"
"smlal v18.4s, v29.4h, v0.h[5]\n"
"smlal2 v19.4s, v29.8h, v0.h[5]\n"
"smlal v20.4s, v30.4h, v0.h[6]\n"
"smlal2 v21.4s, v30.8h, v0.h[6]\n"
"smlal v22.4s, v31.4h, v0.h[7]\n"
"smlal2 v23.4s, v31.8h, v0.h[7]\n"
"add v16.4s, v16.4s, v18.4s\n"
"add v20.4s, v20.4s, v22.4s\n"
"add v17.4s, v17.4s, v19.4s\n"
"add v21.4s, v21.4s, v23.4s\n"
"add v16.4s, v16.4s, v20.4s\n"
"add v17.4s, v17.4s, v21.4s\n"
"st1 {v16.4s, v17.4s}, [%[output]], 32\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}
// Overview of register layout:
//
// A 8x1 cell of Lhs is stored in 16bit in v24-v27
......@@ -636,7 +712,7 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
......@@ -648,8 +724,17 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
if (N - n >= 4) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * CALCBLK;
output += MB * CALCBLK;
n += 4;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
......
......@@ -390,7 +390,7 @@ void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf,
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16,
megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16,
bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start,
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
......
......@@ -875,8 +875,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x8::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() &&
kern_size_param.N % 4 == 0 && !kern_size_param.trA &&
!kern_size_param.trB;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace(
......@@ -911,8 +910,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable(
kern_size_param.A_type == dtype::Int16() &&
kern_size_param.B_type == dtype::Int16() &&
kern_size_param.C_type == dtype::Int32() &&
kern_size_param.N % 4 == 0 && !kern_size_param.trA &&
!kern_size_param.trB;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace(
......@@ -969,8 +967,7 @@ bool MatrixMulImpl::AlgoF16MK8_4x8::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16() &&
kern_size_param.format == param::MatrixMul::Format::MK8 &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.N % 4 == 0;
!kern_size_param.trA && !kern_size_param.trB;
}
size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace(
......
......@@ -21,6 +21,66 @@ using namespace armv7::matmul;
namespace {
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
LDB = (LDB - 4) * sizeof(dt_float16);
asm volatile(
"subs %[K], #8\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vmul.f16 q12, q4, d0[0]\n"
"vmul.f16 q13, q5, d0[1]\n"
"vmul.f16 q14, q6, d0[2]\n"
"vmul.f16 q15, q7, d0[3]\n"
"beq 2f\n"
"1:\n"
"vmla.f16 q12, q8, d1[0]\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vmla.f16 q13, q9, d1[1]\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vmla.f16 q14, q10, d1[2]\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vmla.f16 q15, q11, d1[3]\n"
"vmla.f16 q12, q4, d0[0]\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vmla.f16 q13, q5, d0[1]\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vmla.f16 q14, q6, d0[2]\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vmla.f16 q15, q7, d0[3]\n"
"subs %[K], #8\n"
"bne 1b\n"
"2:\n"
"vmla.f16 q12, q8, d1[0]\n"
"vmla.f16 q13, q9, d1[1]\n"
"vmla.f16 q14, q10, d1[2]\n"
"vmla.f16 q15, q11, d1[3]\n"
"vadd.f16 q12, q12, q14\n"
"vadd.f16 q13, q13, q15\n"
"vadd.f16 q12, q12, q13\n"
"vst1.32 {d24, d25}, [%[output]]!\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24",
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory");
}
// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v4-v11
......@@ -45,7 +105,7 @@ namespace {
// | v3[0-7]| |v15[0-7]|
// +--------+ +--------+--------+
// Accumulator
void kern_4x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48
//! here.
......@@ -179,19 +239,25 @@ void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA,
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 4;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
dt_float16* output = C + (m / MB) * LDC;
const dt_float16* cur_B = B;
for (size_t n = 0; n < N; n += NB) {
kern_4x8(A, cur_B, LDB, K, output);
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
}
......
......@@ -20,6 +20,58 @@ using namespace armv7::matmul;
namespace {
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = (LDB - 4) * sizeof(float);
asm volatile(
"subs %[K], %[K], #4\n"
"vld1.32 {d8-d11}, [%[A]]!\n"
"vld1.32 {d12-d15}, [%[A]]!\n"
"veor q8, q8 \n"
"veor q9, q9 \n"
"veor q10, q10 \n"
"veor q11, q11 \n"
"vld1.32 {d0-d1}, [%[B]]!\n"
"vmla.f32 q8, q4, d0[0]\n"
"vmla.f32 q9, q5, d0[1]\n"
"beq 2f\n"
"1:\n"
"vld1.32 {d8-d11}, [%[A]]!\n"
"vmla.f32 q10, q6, d1[0]\n"
"vmla.f32 q11, q7, d1[1]\n"
"add %[B], %[B], %[LDB]\n"
"vld1.32 {d0-d1}, [%[B]]!\n"
"vld1.32 {d12-d15}, [%[A]]!\n"
"vmla.f32 q8, q4, d0[0]\n"
"vmla.f32 q9, q5, d0[1]\n"
"subs %[K], %[K], #4\n"
"bne 1b\n"
"2:\n"
"vmla.f32 q10, q6, d1[0]\n"
"vmla.f32 q11, q7, d1[1]\n"
"vadd.f32 q8, q8, q10\n"
"vadd.f32 q9, q9, q11\n"
"vadd.f32 q8, q8, q9\n"
"vst1.32 {d16, d17}, [%[C]]!\n"
: [ A ] "+r"(A), [ B ] "+r"(B), [ K ] "+r"(K), [ C ] "+r"(C)
: [ LDB ] "r"(LDB)
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc",
"memory");
}
// Overview of register layout:
//
// A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time
......@@ -268,9 +320,9 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B,
constexpr size_t MB = 4;
constexpr size_t KB = 4;
constexpr size_t NB = 8;
constexpr size_t CALCBLK = 4;
constexpr size_t NB_HALF = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
......@@ -282,8 +334,17 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B,
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
if (N - n >= 4) {
kern_4x4(A, cur_B, LDB, K, output);
cur_B += KB * NB_HALF;
output += MB * NB_HALF;
n += 4;
}
while (n < N) {
kern_4x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
......
......@@ -20,6 +20,91 @@ using namespace armv7::matmul;
namespace {
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 16 number from B, but the pos add 16 * 2, so we minus 16
//! here.
LDB = (LDB - 4) * sizeof(dt_int16);
asm volatile(
"subs %[K], #8\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vmull.s16 q12, d8, d0[0]\n"
"vmull.s16 q13, d9, d0[0]\n"
"vmull.s16 q14, d10, d0[1]\n"
"vmull.s16 q15, d11, d0[1]\n"
"vmlal.s16 q12, d12, d0[2]\n"
"vmlal.s16 q13, d13, d0[2]\n"
"vmlal.s16 q14, d14, d0[3]\n"
"vmlal.s16 q15, d15, d0[3]\n"
"beq 2f\n"
"1:\n"
"vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n"
"vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n"
"vld1.32 {d0}, [%[b_ptr]]!\n"
"vmlal.s16 q12, d16, d1[0]\n"
"vmlal.s16 q13, d17, d1[0]\n"
"vmlal.s16 q14, d18, d1[1]\n"
"vmlal.s16 q15, d19, d1[1]\n"
"vmlal.s16 q12, d20, d1[2]\n"
"vmlal.s16 q13, d21, d1[2]\n"
"vmlal.s16 q14, d22, d1[3]\n"
"vmlal.s16 q15, d23, d1[3]\n"
"vld1.32 {d1}, [%[b_ptr]], %[LDB]\n"
"vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n"
"vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n"
"vmlal.s16 q12, d8, d0[0]\n"
"vmlal.s16 q13, d9, d0[0]\n"
"vmlal.s16 q14, d10, d0[1]\n"
"vmlal.s16 q15, d11, d0[1]\n"
"vmlal.s16 q12, d12, d0[2]\n"
"vmlal.s16 q13, d13, d0[2]\n"
"vmlal.s16 q14, d14, d0[3]\n"
"vmlal.s16 q15, d15, d0[3]\n"
"subs %[K], %[K], #8\n"
"bne 1b\n"
"2:\n"
"vmlal.s16 q12, d16, d1[0]\n"
"vmlal.s16 q13, d17, d1[0]\n"
"vmlal.s16 q14, d18, d1[1]\n"
"vmlal.s16 q15, d19, d1[1]\n"
"vmlal.s16 q12, d20, d1[2]\n"
"vmlal.s16 q13, d21, d1[2]\n"
"vmlal.s16 q14, d22, d1[3]\n"
"vmlal.s16 q15, d23, d1[3]\n"
"vadd.s32 q12, q12, q14\n"
"vadd.s32 q13, q13, q15\n"
"vst1.32 {d24, d25, d26, d27}, [%[output]]!\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24",
"d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory");
}
// Overview of register layout:
//
// A 4x8 cell of Rhs is stored in 16bit in q0-q3
......@@ -40,7 +125,7 @@ namespace {
// | q3[0-7]| |q14[0-3]|v15[0-3]|
// +--------+ +--------+--------+
// Accumulator
void kern_4x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 16 number from B, but the pos add 16 * 2, so we minus 16
//! here.
......@@ -247,19 +332,25 @@ void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 4;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
for (size_t m = 0; m < M; m += MB) {
dt_int32* output = C + (m / MB) * LDC;
const dt_int16* cur_B = B;
for (size_t n = 0; n < N; n += NB) {
kern_4x8(A, cur_B, LDB, K, output);
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x4(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
while (n < N) {
kern_8x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
}
......
......@@ -427,9 +427,6 @@ public:
"The winograd remain oc is not times of OC_BLOCK_SIZE");
if (format == param::MatrixMul::Format::MK4 ||
format == param::MatrixMul::Format::MK8) {
#if !MEGDNN_X86
nr_tiles_in_unit = round_up<size_t>(nr_tiles_in_unit, 4);
#endif
megdnn_assert(nr_tiles_in_unit <= unit_tile_size,
"nr_tiles_in_unit: %zu TILE_SIZE:%zu",
nr_tiles_in_unit, unit_tile_size);
......
......@@ -38,10 +38,9 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
}
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 4);
"AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 1);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -52,10 +51,9 @@ TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) {
}
TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 4);
"AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1);
}
#endif
......@@ -116,10 +114,9 @@ TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) {
}
TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) {
//! nbase should be 4 in order to test the last rest 4 in N dim
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(), "AARCH64_INT16X16X32_MK8_8X8",
param::MatrixMul::Format::MK8, 4);
param::MatrixMul::Format::MK8, 1);
}
//! FIXME: need to add tests of GEMV and QUINT8
......
......@@ -26,7 +26,7 @@ TEST_F(ARMV7, MATRIX_MUL) {
TEST_F(ARMV7, MATRIX_MUL_MK4) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4);
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
......@@ -66,7 +66,7 @@ TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) {
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) {
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{},
handle(), "ARMV7_INT16X16X32_MK8_4X8",
param::MatrixMul::Format::MK8, 4);
param::MatrixMul::Format::MK8, 1);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -78,7 +78,7 @@ TEST_F(ARMV7, MATRIX_MUL_FP16) {
TEST_F(ARMV7, MATRIX_MUL_F16_MK8) {
matrix_mul::check_matrix_mul(
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(),
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4);
"AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 1);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册