Commit f99c34c8 authored by TianXiaogang, committed by yiicy

add winograd f23 implement (#2584)

Parent fbb0d3b5
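"f23" here means Winograd F(2,3), i.e. F(2x2, 3x3) in 2D: each 4x4 input tile yields a 2x2 output tile of a 3x3 convolution with 16 element-wise multiplies instead of 36, which pays off on small feature maps where the existing F(6,3) path has too few tiles to keep threads busy. The diffs below add the 4x4 weight transform (`weight_trans_c4_4x4`), the F(2,3) compute entry points, a bias/ReLU-free variant of the c4 GEMM, and the dispatch between F(6,3) and F(2,3). As background only (not this commit's c4-blocked code), a minimal scalar sketch of the 1D F(2,3) transforms with the standard Winograd matrices:

```cpp
#include <array>

// Minimal scalar sketch of 1D Winograd F(2,3): two outputs of a 3-tap
// convolution from four inputs, using the standard B^T, G, A^T matrices.
// Illustrative reference only; the kernels added below work on 4x4 tiles of
// channel-blocked (c4) data and batch the element-wise products as GEMMs.
std::array<float, 2> winograd_f23_1d(const float d[4], const float g[3]) {
  // Filter transform U = G * g
  const float u0 = g[0];
  const float u1 = 0.5f * (g[0] + g[1] + g[2]);
  const float u2 = 0.5f * (g[0] - g[1] + g[2]);
  const float u3 = g[2];
  // Input transform V = B^T * d
  const float v0 = d[0] - d[2];
  const float v1 = d[1] + d[2];
  const float v2 = d[2] - d[1];
  const float v3 = d[1] - d[3];
  // Element-wise product M = U .* V, then output transform y = A^T * M
  const float m0 = u0 * v0, m1 = u1 * v1, m2 = u2 * v2, m3 = u3 * v3;
  // y0 = d0*g0 + d1*g1 + d2*g2,  y1 = d1*g0 + d2*g1 + d3*g2
  return {m0 + m1 + m2, m1 - m2 - m3};
}
```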
@@ -316,7 +316,9 @@ void fill_bias_int8(int* tensor,
int channel_size);
// new winograd
void weight_trans_c4(
void weight_trans_c4_8x8(
float* dest, const float* src, int ic, int oc, void* workspace);
void weight_trans_c4_4x4(
float* dest, const float* src, int ic, int oc, void* workspace);
void conv_compute_6x6_3x3(const float* input,
float* output,
@@ -331,6 +333,32 @@ void conv_compute_6x6_3x3(const float* input,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_compute_2x2_3x3(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_compute_2x2_3x3_small(const float* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weight,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
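The new `conv_compute_2x2_3x3` / `conv_compute_2x2_3x3_small` entry points consume 4x4 input tiles. For orientation, a scalar sketch of the F(2x2,3x3) input transform V = B^T * d * B on one tile, again using the standard matrices rather than the c4-blocked code this commit adds:

```cpp
// Scalar sketch of the F(2x2,3x3) input transform V = B^T * d * B on a
// single 4x4 tile (standard Winograd matrices; assumption, not the c4 kernels).
void input_transform_f23_tile(const float d[4][4], float v[4][4]) {
  float t[4][4];
  // Left-multiply by B^T (acts on the columns of d).
  for (int j = 0; j < 4; ++j) {
    t[0][j] = d[0][j] - d[2][j];
    t[1][j] = d[1][j] + d[2][j];
    t[2][j] = d[2][j] - d[1][j];
    t[3][j] = d[1][j] - d[3][j];
  }
  // Right-multiply by B (acts on the rows of t).
  for (int i = 0; i < 4; ++i) {
    v[i][0] = t[i][0] - t[i][2];
    v[i][1] = t[i][1] + t[i][2];
    v[i][2] = t[i][2] - t[i][1];
    v[i][3] = t[i][1] - t[i][3];
  }
}
```

The filter transform (G * g * G^T, 3x3 to 4x4) is presumably what `weight_trans_c4_4x4` precomputes into the packed weight tensor, and the output transform A^T * m * A maps each 4x4 product tile back to 2x2 outputs.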
@@ -695,7 +695,6 @@ void sgemm_prepack_c4_common(int M,
}
}
}
void sgemm_prepack_c4_small(int M,
int N,
int K,
@@ -1146,6 +1145,540 @@ void sgemm_prepack_c4_small(int M,
}
}
void sgemm_prepack_c4_small(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
ARMContext* ctx) {
const int m_round = (M + 3) / 4 * 4;
const int k_round = (K + 3) / 4 * 4;
const int mloop = m_round >> 2;
const int lda = 4 * k_round;
const int ldb_byte = 4 * N * sizeof(float);
const int kcnt = k_round >> 2;
#ifdef __aarch64__
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int m = 0; m < mloop; ++m) {
const float* b = B;
int n = N;
#ifdef __aarch64__
for (; n > 7; n -= 8) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
// clang-format off
asm volatile(
"0:\n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
/* load b2, b3 */
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
/* load a2, a3 */
"fmul v8.4s, v16.4s, v0.s[0] \n"
"fmul v9.4s, v16.4s, v1.s[0] \n"
"fmul v10.4s, v16.4s, v2.s[0] \n"
"fmul v11.4s, v16.4s, v3.s[0] \n"
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"prfm pldl1keep, [%[b]] \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
/* load b4, b5 */
"ld1 {v4.4s, v5.4s}, [%[b]], #32 \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load b6, b7 */
"ld1 {v6.4s, v7.4s}, [%[b]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"sub %[b], %[b], #128 \n"
"fmul v12.4s, v16.4s, v4.s[0] \n"
"fmul v13.4s, v16.4s, v5.s[0] \n"
"fmul v14.4s, v16.4s, v6.s[0] \n"
"fmul v15.4s, v16.4s, v7.s[0] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v12.4s, v17.4s, v4.s[1] \n"
"fmla v13.4s, v17.4s, v5.s[1] \n"
"fmla v14.4s, v17.4s, v6.s[1] \n"
"fmla v15.4s, v17.4s, v7.s[1] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v12.4s, v18.4s, v4.s[2] \n"
"fmla v13.4s, v18.4s, v5.s[2] \n"
"fmla v14.4s, v18.4s, v6.s[2] \n"
"fmla v15.4s, v18.4s, v7.s[2] \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"fmla v12.4s, v19.4s, v4.s[3] \n"
"fmla v13.4s, v19.4s, v5.s[3] \n"
"fmla v14.4s, v19.4s, v6.s[3] \n"
"fmla v15.4s, v19.4s, v7.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"beq 2f \n"
"1:\n"
/* load b2, b3 */
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v16.4s, v1.s[0] \n"
"fmla v10.4s, v16.4s, v2.s[0] \n"
"fmla v11.4s, v16.4s, v3.s[0] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"prfm pldl1keep, [%[b]] \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
/* load b4, b5 */
"ld1 {v4.4s, v5.4s}, [%[b]], #32 \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load b6, b7 */
"ld1 {v6.4s, v7.4s}, [%[b]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"sub %[b], %[b], #128 \n"
"fmla v12.4s, v16.4s, v4.s[0] \n"
"fmla v13.4s, v16.4s, v5.s[0] \n"
"fmla v14.4s, v16.4s, v6.s[0] \n"
"fmla v15.4s, v16.4s, v7.s[0] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v12.4s, v17.4s, v4.s[1] \n"
"fmla v13.4s, v17.4s, v5.s[1] \n"
"fmla v14.4s, v17.4s, v6.s[1] \n"
"fmla v15.4s, v17.4s, v7.s[1] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v12.4s, v18.4s, v4.s[2] \n"
"fmla v13.4s, v18.4s, v5.s[2] \n"
"fmla v14.4s, v18.4s, v6.s[2] \n"
"fmla v15.4s, v18.4s, v7.s[2] \n"
/* load b0, b1 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"fmla v12.4s, v19.4s, v4.s[3] \n"
"fmla v13.4s, v19.4s, v5.s[3] \n"
"fmla v14.4s, v19.4s, v6.s[3] \n"
"fmla v15.4s, v19.4s, v7.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"bne 1b \n"
"2:\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n"
"st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte),
[vzero] "w" (vzero)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "cc", "memory"
);
b += 4 * 8;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* load b0-b3 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
"fmul v8.4s, v16.4s, v0.s[0] \n"
"fmul v9.4s, v16.4s, v1.s[0] \n"
"fmul v10.4s, v16.4s, v2.s[0] \n"
"fmul v11.4s, v16.4s, v3.s[0] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #64 \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"beq 2f \n"
"1:\n"
/* load b0-b3 */
"ld1 {v0.4s, v1.4s}, [%[b]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[b]], #32 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v16.4s, v1.s[0] \n"
"fmla v10.4s, v16.4s, v2.s[0] \n"
"fmla v11.4s, v16.4s, v3.s[0] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #64 \n"
"fmla v8.4s, v17.4s, v0.s[1] \n"
"fmla v9.4s, v17.4s, v1.s[1] \n"
"fmla v10.4s, v17.4s, v2.s[1] \n"
"fmla v11.4s, v17.4s, v3.s[1] \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v18.4s, v1.s[2] \n"
"fmla v10.4s, v18.4s, v2.s[2] \n"
"fmla v11.4s, v18.4s, v3.s[2] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"fmla v8.4s, v19.4s, v0.s[3] \n"
"fmla v9.4s, v19.4s, v1.s[3] \n"
"fmla v10.4s, v19.4s, v2.s[3] \n"
"fmla v11.4s, v19.4s, v3.s[3] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"bne 1b \n"
"2:\n"
"st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte),
[vzero] "w" (vzero)
: "v0", "v1", "v2", "v3", "v8", "v9",
"v10", "v11", "v16", "v17", "v18",
"v19", "cc", "memory"
);
b += 4 * 4;
}
for (; n > 0; n--) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
/* load b0 */
"ld1 {v0.4s}, [%[b]], #16 \n"
"fmul v8.4s, v16.4s, v0.s[0] \n"
"fmul v9.4s, v17.4s, v0.s[1] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #16 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v19.4s, v0.s[3] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"beq 2f \n"
"1:\n"
/* load b0 */
"ld1 {v0.4s}, [%[b]], #16 \n"
"fmla v8.4s, v16.4s, v0.s[0] \n"
"fmla v9.4s, v17.4s, v0.s[1] \n"
/* load a2, a3 */
"ld1 {v18.4s, v19.4s}, [%[a]], #32 \n"
"sub %[b], %[b], #16 \n"
"subs %w[cnt], %w[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"fmla v8.4s, v18.4s, v0.s[2] \n"
"fmla v9.4s, v19.4s, v0.s[3] \n"
/* load a0, a1 */
"ld1 {v16.4s, v17.4s}, [%[a]], #32 \n"
"bne 1b \n"
"fadd v8.4s, v8.4s, v9.4s \n"
"2:\n"
"st1 {v8.4s}, [%[c]], #16 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte),
[vzero] "w" (vzero)
: "v0", "v8", "v9", "v16", "v17",
"v18", "v19", "cc", "memory"
);
b += 4;
}
#else
for (; n > 7; n -= 8) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
// clang-format off
asm volatile(
"0:\n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vld1.32 {d0-d3}, [%[b]]! \n"
/* load b2, b3 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmul.f32 q8, q4, d0[0] \n"
"vmul.f32 q9, q4, d2[0] \n"
"vmul.f32 q10, q4, d4[0] \n"
"vmul.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]! \n"
"pld [%[b]] \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
"pld [%[b], #64] \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
/* load b4, b5 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
/* load b6, b7 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmul.f32 q12, q4, d0[0] \n"
"vmul.f32 q13, q4, d2[0] \n"
"vmul.f32 q14, q4, d4[0] \n"
"vmul.f32 q15, q4, d6[0] \n"
"sub %[b], %[b], #128 \n"
"vmla.f32 q12, q5, d0[1] \n"
"vmla.f32 q13, q5, d2[1] \n"
"vmla.f32 q14, q5, d4[1] \n"
"vmla.f32 q15, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q12, q6, d1[0] \n"
"vmla.f32 q13, q6, d3[0] \n"
"vmla.f32 q14, q6, d5[0] \n"
"vmla.f32 q15, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q12, q7, d1[1] \n"
"vmla.f32 q13, q7, d3[1] \n"
/* load b0, b1 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q14, q7, d5[1] \n"
"vmla.f32 q15, q7, d7[1] \n"
"beq 2f \n"
"1:\n"
/* load b2, b3 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q4, d2[0] \n"
"vmla.f32 q10, q4, d4[0] \n"
"vmla.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]! \n"
"pld [%[b]] \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
"pld [%[b], #64] \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
/* load b4, b5 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
/* load b6, b7 */
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmla.f32 q12, q4, d0[0] \n"
"vmla.f32 q13, q4, d2[0] \n"
"vmla.f32 q14, q4, d4[0] \n"
"vmla.f32 q15, q4, d6[0] \n"
"sub %[b], %[b], #128 \n"
"vmla.f32 q12, q5, d0[1] \n"
"vmla.f32 q13, q5, d2[1] \n"
"vmla.f32 q14, q5, d4[1] \n"
"vmla.f32 q15, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q12, q6, d1[0] \n"
"vmla.f32 q13, q6, d3[0] \n"
"vmla.f32 q14, q6, d5[0] \n"
"vmla.f32 q15, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q12, q7, d1[1] \n"
"vmla.f32 q13, q7, d3[1] \n"
/* load b0, b1 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vmla.f32 q14, q7, d5[1] \n"
"vmla.f32 q15, q7, d7[1] \n"
"bne 1b \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]! \n"
"vst1.32 {d20-d23}, [%[c]]! \n"
"vst1.32 {d24-d27}, [%[c]]! \n"
"vst1.32 {d28-d31}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4", "q5",
"q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "cc", "memory"
);
b += 4 * 8;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
/* load b0-b3 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmul.f32 q8, q4, d0[0] \n"
"vmul.f32 q9, q4, d2[0] \n"
"vmul.f32 q10, q4, d4[0] \n"
"vmul.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]!\n"
"sub %[b], %[b], #64 \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"beq 2f \n"
"1:\n"
/* load b0-b3 */
"vld1.32 {d0-d3}, [%[b]]! \n"
"vld1.32 {d4-d7}, [%[b]]! \n"
"vmla.f32 q8, q4, d0[0] \n"
"vmla.f32 q9, q4, d2[0] \n"
"vmla.f32 q10, q4, d4[0] \n"
"vmla.f32 q11, q4, d6[0] \n"
/* load a2, a3 */
"vld1.32 {d12-d15}, [%[a]]!\n"
"sub %[b], %[b], #64 \n"
"vmla.f32 q8, q5, d0[1] \n"
"vmla.f32 q9, q5, d2[1] \n"
"vmla.f32 q10, q5, d4[1] \n"
"vmla.f32 q11, q5, d6[1] \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q8, q6, d1[0] \n"
"vmla.f32 q9, q6, d3[0] \n"
"vmla.f32 q10, q6, d5[0] \n"
"vmla.f32 q11, q6, d7[0] \n"
/* load a0, a1 */
"vld1.32 {d8-d11}, [%[a]]! \n"
"vmla.f32 q8, q7, d1[1] \n"
"vmla.f32 q9, q7, d3[1] \n"
"vmla.f32 q10, q7, d5[1] \n"
"vmla.f32 q11, q7, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"bne 1b \n"
"2:\n"
"vst1.32 {d16-d19}, [%[c]]!\n"
"vst1.32 {d20-d23}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4", "q5",
"q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "cc", "memory"
);
b += 4 * 4;
}
for (; n > 0; n--) {
int cnt = kcnt;
const float* a_ptr = A_packed;
const float* b_ptr = b;
asm volatile(
"0:\n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
/* load b0 */
"vld1.32 {d0-d1}, [%[b]]! \n"
"vmul.f32 q5, q1, d0[0] \n"
"vmul.f32 q6, q2, d0[1] \n"
/* load a2, a3 */
"vld1.32 {d6-d9}, [%[a]]! \n"
"sub %[b], %[b], #16 \n"
"subs %[cnt], %[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q5, q3, d1[0] \n"
"vmla.f32 q6, q4, d1[1] \n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
"beq 2f \n"
"1:\n"
/* load b0 */
"vld1.32 {d0-d1}, [%[b]]! \n"
"vmla.f32 q5, q1, d0[0] \n"
"vmla.f32 q6, q2, d0[1] \n"
/* load a2, a3 */
"vld1.32 {d6-d9}, [%[a]]! \n"
"sub %[b], %[b], #16 \n"
"subs %[cnt], %[cnt], #1 \n"
"add %[b], %[b], %[ldb] \n"
"vmla.f32 q5, q3, d1[0] \n"
"vmla.f32 q6, q4, d1[1] \n"
/* load a0, a1 */
"vld1.32 {d2-d5}, [%[a]]! \n"
"bne 1b \n"
"vadd.f32 q5, q5, q6 \n"
"2:\n"
"vst1.32 {d10-d11}, [%[c]]!\n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "cc", "memory"
);
// clang-format on
b += 4;
}
#endif
A_packed += lda;
}
}
void sgemm_prepack_c4(int M,
int N,
int K,
......
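The new `sgemm_prepack_c4_small` above is a bias/ReLU-free overload of the existing c4 GEMM (the original declaration with `has_bias`/`has_relu` shows up as context in the header hunk below). Reading the loads and stores, the kernel appears to assume A packed as [m_round/4][k_round][4] with lda = 4*k_round, B as [k_round/4][N][4] with a row stride of ldb = 4*N floats, and C as [m_round/4][N][4], with padded rows and depths zero-filled by the packing step. A plain scalar sketch of that contract, meant for reading the assembly rather than for use:

```cpp
// Scalar reference for the bias/ReLU-free sgemm_prepack_c4_small overload.
// Layout assumptions (read off the loads/stores above, so treat as a sketch):
//   A_packed: [m_round/4][k_round][4]  -- 4 rows of an M-block per depth k
//   B:        [k_round/4][N][4]        -- ldb = 4 * N floats per k-group
//   C:        [m_round/4][N][4]
// Padded rows/depths are assumed zero-filled by the packing step.
void sgemm_prepack_c4_small_ref(int M, int N, int K, const float* A_packed,
                                const float* B, float* C) {
  const int m_round = (M + 3) / 4 * 4;
  const int k_round = (K + 3) / 4 * 4;
  for (int mb = 0; mb < m_round / 4; ++mb) {
    const float* a_block = A_packed + mb * 4 * k_round;
    float* c_block = C + mb * 4 * N;
    for (int n = 0; n < N; ++n) {
      float acc[4] = {0.f, 0.f, 0.f, 0.f};
      for (int k = 0; k < k_round; ++k) {
        // B element for column n at depth k, 4-interleaved within each k-group.
        const float bval = B[(k / 4) * 4 * N + n * 4 + (k % 4)];
        const float* a = a_block + k * 4;  // the 4 M-block rows at depth k
        for (int r = 0; r < 4; ++r) acc[r] += a[r] * bval;
      }
      for (int r = 0; r < 4; ++r) c_block[n * 4 + r] = acc[r];
    }
  }
}
```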
@@ -47,6 +47,13 @@ void sgemm_prepack_c4_small(int M,
bool has_bias,
bool has_relu,
ARMContext* ctx);
void sgemm_prepack_c4_small(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
@@ -68,19 +68,9 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
bool use_winograd =
(threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
pads_equal) ||
(oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
if (use_winograd) {
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
} else {
/// direct conv impl
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv";
}
} else if (param.groups == 1 && kw == 3 && stride == 2 &&
chin * chout < 4 * hin * win && kps_equal && no_dilation) {
/// direct conv impl
......
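The hunk above (19 lines replaced by 9) drops the `use_winograd` size/thread heuristic, which suggests every groups==1, 3x3, stride-1, non-dilated conv now goes to `WinogradConv`, with the F(6,3)/F(2,3) choice made inside that kernel. The resulting body is elided in this diff, so the condensation below is an assumption, not the committed code:

```cpp
#include <iostream>

// Hypothetical condensation of the simplified dispatch after this commit;
// the actual resulting body of conv_compute.cc is not shown in the diff.
const char* pick_impl(int groups, int kw, int stride, bool kps_equal,
                      bool no_dilation) {
  if (groups == 1 && kw == 3 && stride == 1 && kps_equal && no_dilation) {
    return "WinogradConv";  // heuristic removed: always the Winograd kernel
  }
  return "DirectConv / GemmLikeConv";  // remaining branches elided above
}

int main() {
  std::cout << pick_impl(1, 3, 1, true, true) << "\n";  // WinogradConv
  std::cout << pick_impl(1, 3, 2, true, true) << "\n";  // falls through
  return 0;
}
```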
@@ -43,16 +43,21 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int parallel_threads =
(((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
if (threads <= 2 && parallel_threads >= threads) {
if (last_kernel_is_c4_ == 1) {
choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
if (choose_small_) {
wino_iw = 4;
if (last_function_ == 0) {
return;
}
last_kernel_is_c4_ = 1;
last_function_ = 0;
} else {
wino_iw = 8;
if (last_function_ == 1) {
return;
}
last_function_ = 1;
}
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
@@ -61,61 +66,24 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
threads;
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<float>();
lite::arm::math::weight_trans_c4(
if (!choose_small_) {
lite::arm::math::weight_trans_c4_8x8(
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
free(trans_tmp_ptr);
} else {
if (last_kernel_is_c4_ == 0) {
return;
}
last_kernel_is_c4_ = 0;
int tile_w = (ow + 5) / 6;
int tile_h = (oh + 5) / 6;
int size_tile = tile_h * tile_w;
int size_trans_channel = 8 * 8 * size_tile;
int max_ch = ic > oc ? ic : oc;
const int n_wino = size_tile;
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
const int m_wino = oc;
int hblock = lite::arm::math::get_hblock(&ctx);
int m_round = hblock * ((m_wino + hblock - 1) / hblock);
weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
sizeof(float));
auto weights_wino =
static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
lite::arm::math::winograd_transform_weights(
weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
auto weights_trans = weights_.mutable_data<float>();
for (int i = 0; i < 64; ++i) {
float* packed_weights = weights_trans + i * m_round * ic;
const float* weights_wino_ptr = weights_wino + i * oc * ic;
lite::arm::math::prepackA(packed_weights,
weights_wino_ptr,
1.f,
ic,
0,
m_wino,
0,
ic,
false,
&ctx);
lite::arm::math::weight_trans_c4_4x4(
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
}
free(trans_tmp_ptr);
free(weights_wino);
}
last_shape_ = x_dims;
}
@@ -145,14 +113,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
int ow = o_dims[3];
int oc = o_dims[1];
int tile_block = 8;
#ifdef __aarch64__
tile_block = 16;
#endif
int threads = ctx.threads();
int parallel_threads =
(((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
if (threads <= 2 && parallel_threads >= threads) {
if (!choose_small_) {
lite::arm::math::conv_compute_6x6_3x3(i_data,
o_data,
bs,
@@ -167,7 +128,11 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
param,
&ctx);
} else {
lite::arm::math::conv_winograd3x3(i_data,
int tile_block = 8;
int block_count =
(((ow + 1) / 2) * ((oh + 1) / 2) + tile_block - 1) / tile_block;
if (block_count != 1) {
lite::arm::math::conv_compute_2x2_3x3(i_data,
o_data,
bs,
oc,
@@ -180,6 +145,21 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
b_data,
param,
&ctx);
} else {
lite::arm::math::conv_compute_2x2_3x3_small(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
param,
&ctx);
}
}
}
......
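To make the F(6,3)/F(2,3) selection concrete: `ReInitWhenNeeded()` sets `choose_small_` when `ow * oh / (tile_block * threads) < 36` (tile_block = 16 on aarch64), and `Run()` falls back to the `_small` kernel when the whole output fits in a single block of 2x2 tiles. A standalone sketch of those thresholds for a few output sizes, with a thread count of 4 assumed purely for illustration:

```cpp
#include <cstdio>

// Reproduces the selection thresholds from ReInitWhenNeeded()/Run() above.
// tile_block = 16 mirrors the aarch64 value; threads = 4 is an assumption.
int main() {
  const int tile_block = 16;
  const int threads = 4;
  const int shapes[][2] = {{56, 56}, {14, 14}, {4, 4}};  // {oh, ow}
  for (const auto& s : shapes) {
    const int oh = s[0], ow = s[1];
    const bool choose_small = ow * oh / (tile_block * threads) < 36;
    // In Run(), the F(2,3) path re-tiles into 2x2 outputs with tile_block = 8.
    const int block_count = (((ow + 1) / 2) * ((oh + 1) / 2) + 8 - 1) / 8;
    const char* path = !choose_small        ? "conv_compute_6x6_3x3"
                       : (block_count != 1) ? "conv_compute_2x2_3x3"
                                            : "conv_compute_2x2_3x3_small";
    printf("%2dx%2d -> %s\n", oh, ow, path);
  }
  return 0;
}
```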
@@ -40,7 +40,9 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
Tensor weights_;
DDim last_shape_;
int workspace_size_{0};
int last_kernel_is_c4_{-1};
int last_function_{-1};
bool choose_small_{false};
int wino_iw{8};
};
} // namespace arm
......