diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S
new file mode 100644
index 0000000000000000000000000000000000000000..c1d3f4498c03806b430b3542e88f2350a6557e72
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S
@@ -0,0 +1,784 @@
+#ifdef __aarch64__
+    .text
+    .align 5
+    .global MatmulFloatNeon64Opt
+#ifndef __APPLE__
+    .type MatmulFloatNeon64Opt, %function
+#endif
+
+// A: LM  [row_12 * depth] col_12_major
+// B: RM  [depth * col_8]  row_8_major
+// C: A*B [row * col] rows `stride` floats apart when write_nhwc != 0,
+//         consecutive 12x8 tiles when write_nhwc == 0
+///////////////////////////////////////////////////////////////////////////////
+// One depth step                                    RM 1x8 block
+//                           /-----------------------------------------\
+//                           |v3.s[0]  ...  v3.s[3]   v4.s[0]  ...  v4.s[3]|
+//                           \-----------------------------------------/
+//        LM 12x1 block
+//  /---------------------\  /-----------------------------------------\
+//  |        v0.s[0]      |  |v8.s[0] ... v8.s[3]    v9.s[0] ... v9.s[3] |
+//  |         ...         |  |  ...                              ...   |
+//  |        v0.s[3]      |  |v14.s[0]...v14.s[3]   v15.s[0]...v15.s[3]|
+//  |        v1.s[0]      |  |v16.s[0]...v16.s[3]   v17.s[0]...v17.s[3]|
+//  |         ...         |  |  ...                              ...   |
+//  |        v1.s[3]      |  |v22.s[0]...v22.s[3]   v23.s[0]...v23.s[3]|
+//  |        v2.s[0]      |  |v24.s[0]...v24.s[3]   v25.s[0]...v25.s[3]|
+//  |         ...         |  |  ...                              ...   |
+//  |        v2.s[3]      |  |v30.s[0]...v30.s[3]   v31.s[0]...v31.s[3]|
+//  \---------------------/  \-----------------------------------------/
+//                                      accumulators 12x8 block
+//
+///////////////////////////////////////////////////////////////////////////////
+// Loop nest
+//   L1: one pass per 8 output columns (rhs/bias advance, w7 -= 8)
+//   L2: one pass per 12 output rows   (lhs keeps advancing, w10 -= 12)
+//   LoopStart/Loop/LoopEnd: software-pipelined walk over depth; the loads for
+//   step d+1 are interleaved with the tail of the multiplies for step d
+//
+// Stack frame
+//   v8-v15 are callee-saved, so they are spilled into a 128-byte frame that
+//   stays allocated (sp is kept below it) until End1. The two stack-passed
+//   arguments are therefore read at [sp, #128] and [sp, #136].
+//
+// x18 is the platform register (reserved on Apple and Android targets), so
+// x9 is used for the scratch/dst-pointer role instead.
+/////////////////////////////////////////////////////////////////////////////////
+//
+// void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
+//                           int row, int col, size_t stride, bool write_nhwc)
+// x0: a
+// x1: b
+// x2: c
+// x3: bias
+// w4: act_type (1: relu, 2: relu6, anything else: none)
+// w5: depth
+// w6: row
+// w7: col
+// stack arg 0 (x17): stride, scaled to bytes (stride * 4) in the prologue
+// stack arg 1 (w13): write_nhwc, reloaded from the stack where needed;
+//                    zero selects the WriteC8 tile layout
+
+MatmulFloatNeon64Opt:
+    sub sp, sp, #128 // keep the save area above sp so nothing asynchronous can overwrite it
+    mov x9, sp
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x9], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
+
+    mov w9, #32 // sizeof(float) * 8
+    mul w15, w5, w9 // block stride of rhs: sizeof(float) * 8 * depth
+    mov x11, x3 // bias flag
+    ldr x17, [sp, #128] // stride: first stack argument, above the 128-byte frame
+    lsl x17, x17, #2 // stride * sizeof(float)
+
+L1:
+    mov w10, w6 // reload lhs row
+    mov x12, x0 // reload lhs ptr
+    mov x9, x2 // reload dst ptr
+
+L2:
+    mov x16, x1 // reload rhs ptr
+    mov w13, w5 // reload depth
+    mov x14, x3 // reload bias ptr
+    dup v8.4s, wzr
+    dup v9.4s, wzr
+    dup v10.4s, wzr
+    dup v11.4s, wzr
+    dup v12.4s, wzr
+    dup v13.4s, wzr
+    dup v14.4s, wzr
+    dup v15.4s, wzr
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+
+LoopStart:
+    ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
+    ld1 {v3.4s, v4.4s}, [x16], #32
+    fmla v8.4s, v3.4s, v0.s[0]
+    fmla v10.4s, v3.4s, v0.s[1]
+    fmla v12.4s, v3.4s, v0.s[2]
+    fmla v14.4s, v3.4s, v0.s[3]
+    fmla v9.4s, v4.4s, v0.s[0]
+    fmla v11.4s, v4.4s, v0.s[1]
+    fmla v13.4s, v4.4s, v0.s[2]
+    fmla v15.4s, v4.4s, v0.s[3]
+
+    subs w13, w13, #1
+    beq LoopEnd
+
+Loop:
+    ld1 {v0.4s}, [x12], #16
+    fmla v16.4s, v3.4s, v1.s[0]
+    fmla v18.4s, v3.4s, v1.s[1]
+    fmla v20.4s, v3.4s, v1.s[2]
+    fmla v22.4s, v3.4s, v1.s[3]
+    fmla v17.4s, v4.4s, v1.s[0]
+    fmla v19.4s, v4.4s, v1.s[1]
+    fmla v21.4s, v4.4s, v1.s[2]
+    fmla v23.4s, v4.4s, v1.s[3]
+    ld1 {v1.4s}, [x12], #16
+    fmla v24.4s, v3.4s, v2.s[0]
+    fmla v26.4s, v3.4s, v2.s[1]
+    fmla v28.4s, v3.4s, v2.s[2]
+    fmla v30.4s, v3.4s, v2.s[3]
+    ld1 {v3.4s}, [x16], #16
+    fmla v25.4s, v4.4s, v2.s[0]
+    fmla v27.4s, v4.4s, v2.s[1]
+    fmla v29.4s, v4.4s, v2.s[2]
+    fmla v31.4s, v4.4s, v2.s[3]
+    ld1 {v4.4s}, [x16], #16
+    fmla v8.4s, v3.4s, v0.s[0]
+    fmla v10.4s, v3.4s, v0.s[1]
+    fmla v12.4s, v3.4s, v0.s[2]
+    fmla v14.4s, v3.4s, v0.s[3]
+    ld1 {v2.4s}, [x12], #16
+    fmla v9.4s, v4.4s, v0.s[0]
+    fmla v11.4s, v4.4s, v0.s[1]
+    fmla v13.4s, v4.4s, v0.s[2]
+    fmla v15.4s, v4.4s, v0.s[3]
+
+    subs w13, w13, #1
+    bgt Loop
+
+LoopEnd:
+    fmla v16.4s, v3.4s, v1.s[0]
+    fmla v18.4s, v3.4s, v1.s[1]
+    fmla v20.4s, v3.4s, v1.s[2]
+    fmla v22.4s, v3.4s, v1.s[3]
+    fmla v17.4s, v4.4s, v1.s[0]
+    fmla v19.4s, v4.4s, v1.s[1]
+    fmla v21.4s, v4.4s, v1.s[2]
+    fmla v23.4s, v4.4s, v1.s[3]
+    fmla v24.4s, v3.4s, v2.s[0]
+    fmla v26.4s, v3.4s, v2.s[1]
+    fmla v28.4s, v3.4s, v2.s[2]
+    fmla v30.4s, v3.4s, v2.s[3]
+    fmla v25.4s, v4.4s, v2.s[0]
+    fmla v27.4s, v4.4s, v2.s[1]
+    fmla v29.4s, v4.4s, v2.s[2]
+    fmla v31.4s, v4.4s, v2.s[3]
+
+Bias:
+    cbz x11, Activation
+    ld1 {v0.4s}, [x14], #16
+    ld1 {v1.4s}, [x14], #16
+    fadd v8.4s, v8.4s, v0.4s
+    fadd v9.4s, v9.4s, v1.4s
+    fadd v10.4s, v10.4s, v0.4s
+    fadd v11.4s, v11.4s, v1.4s
+    fadd v12.4s, v12.4s, v0.4s
+    fadd v13.4s, v13.4s, v1.4s
+    fadd v14.4s, v14.4s, v0.4s
+    fadd v15.4s, v15.4s, v1.4s
+    fadd v16.4s, v16.4s, v0.4s
+    fadd v17.4s, v17.4s, v1.4s
+    fadd v18.4s, v18.4s, v0.4s
+    fadd v19.4s, v19.4s, v1.4s
+    fadd v20.4s, v20.4s, v0.4s
+    fadd v21.4s, v21.4s, v1.4s
+    fadd v22.4s, v22.4s, v0.4s
+    fadd v23.4s, v23.4s, v1.4s
+    fadd v24.4s, v24.4s, v0.4s
+    fadd v25.4s, v25.4s, v1.4s
+    fadd v26.4s, v26.4s, v0.4s
+    fadd v27.4s, v27.4s, v1.4s
+    fadd v28.4s, v28.4s, v0.4s
+    fadd v29.4s, v29.4s, v1.4s
+    fadd v30.4s, v30.4s, v0.4s
+    fadd v31.4s, v31.4s, v1.4s
+
+Activation:
+    cmp w4, #2
+    beq Relu6
+    cmp w4, #1
+    beq Relu
+    b Write
+
+Relu6:
+    mov w8, #6 // upper clamp first, then fall through into Relu for the lower clamp
+    dup v2.4s, w8
+    scvtf v2.4s, v2.4s
+    fmin v8.4s, v8.4s, v2.4s
+    fmin v9.4s, v9.4s, v2.4s
+    fmin v10.4s, v10.4s, v2.4s
+    fmin v11.4s, v11.4s, v2.4s
+    fmin v12.4s, v12.4s, v2.4s
+    fmin v13.4s, v13.4s, v2.4s
+    fmin v14.4s, v14.4s, v2.4s
+    fmin v15.4s, v15.4s, v2.4s
+    fmin v16.4s, v16.4s, v2.4s
+    fmin v17.4s, v17.4s, v2.4s
+    fmin v18.4s, v18.4s, v2.4s
+    fmin v19.4s, v19.4s, v2.4s
+    fmin v20.4s, v20.4s, v2.4s
+    fmin v21.4s, v21.4s, v2.4s
+    fmin v22.4s, v22.4s, v2.4s
+    fmin v23.4s, v23.4s, v2.4s
+    fmin v24.4s, v24.4s, v2.4s
+    fmin v25.4s, v25.4s, v2.4s
+    fmin v26.4s, v26.4s, v2.4s
+    fmin v27.4s, v27.4s, v2.4s
+    fmin v28.4s, v28.4s, v2.4s
+    fmin v29.4s, v29.4s, v2.4s
+    fmin v30.4s, v30.4s, v2.4s
+    fmin v31.4s, v31.4s, v2.4s
+
+Relu:
+    dup v3.4s, wzr
+    fmax v8.4s, v8.4s, v3.4s
+    fmax v9.4s, v9.4s, v3.4s
+    fmax v10.4s, v10.4s, v3.4s
+    fmax v11.4s, v11.4s, v3.4s
+    fmax v12.4s, v12.4s, v3.4s
+    fmax v13.4s, v13.4s, v3.4s
+    fmax v14.4s, v14.4s, v3.4s
+    fmax v15.4s, v15.4s, v3.4s
+    fmax v16.4s, v16.4s, v3.4s
+    fmax v17.4s, v17.4s, v3.4s
+    fmax v18.4s, v18.4s, v3.4s
+    fmax v19.4s, v19.4s, v3.4s
+    fmax v20.4s, v20.4s, v3.4s
+    fmax v21.4s, v21.4s, v3.4s
+    fmax v22.4s, v22.4s, v3.4s
+    fmax v23.4s, v23.4s, v3.4s
+    fmax v24.4s, v24.4s, v3.4s
+    fmax v25.4s, v25.4s, v3.4s
+    fmax v26.4s, v26.4s, v3.4s
+    fmax v27.4s, v27.4s, v3.4s
+    fmax v28.4s, v28.4s, v3.4s
+    fmax v29.4s, v29.4s, v3.4s
+    fmax v30.4s, v30.4s, v3.4s
+    fmax v31.4s, v31.4s, v3.4s
+
+Write:
+    ldrb w13, [sp, #136] // write_nhwc: second stack argument
+    cbz w13, WriteC8
+    cmp w7, #1
+    beq Write1
+    cmp w7, #2
+    beq Write2
+    cmp w7, #3
+    beq Write3
+    cmp w7, #4
+    beq Write4
+    cmp w7, #5
+    beq Write5
+    cmp w7, #6
+    beq Write6
+    cmp w7, #7
+    beq Write7
+    b Write8
+
+Write1:
+    str s8, [x9]
+    cmp w10, #1
+    beq WriteEnd
+    add x9, x9, x17
+    str s10, [x9]
+    cmp w10, #2
+    beq WriteEnd
+    add x9, x9, x17
+    str s12, [x9]
+    cmp w10, #3
+    beq WriteEnd
+    add x9, x9, x17
+    str s14, [x9]
+    cmp w10, #4
+    beq WriteEnd
+    add x9, x9, x17
+    str s16, [x9]
+    cmp w10, #5
+    beq WriteEnd
+    add x9, x9, x17
+    str s18, [x9]
+    cmp w10, #6
+    beq WriteEnd
+    add x9, x9, x17
+    str s20, [x9]
+    cmp w10, #7
+    beq WriteEnd
+    add x9, x9, x17
+    str s22, [x9]
+    cmp w10, #8
+    beq WriteEnd
+    add x9, x9, x17
+    str s24, [x9]
+    cmp w10, #9
+    beq WriteEnd
+    add x9, x9, x17
+    str s26, [x9]
+    cmp w10, #10
+    beq WriteEnd
+    add x9, x9, x17
+    str s28, [x9]
+    cmp w10, #11
+    beq WriteEnd
+    add x9, x9, x17
+    str s30, [x9]
+    add x9, x9, x17
+    b WriteEnd
+Write2:
+    dup s9, v8.s[1]
+    stp s8, s9, [x9]
+    cmp w10, #1
+    beq WriteEnd
+    add x9, x9, x17
+    dup s11, v10.s[1]
+    stp s10, s11, [x9]
+    cmp w10, #2
+    beq WriteEnd
+    add x9, x9, x17
+    dup s13, v12.s[1]
+    stp s12, s13, [x9]
+    cmp w10, #3
+    beq WriteEnd
+    add x9, x9, x17
+    dup s15, v14.s[1]
+    stp s14, s15, [x9]
+    cmp w10, #4
+    beq WriteEnd
+    add x9, x9, x17
+    dup s17, v16.s[1]
+    stp s16, s17, [x9]
+    cmp w10, #5
+    beq WriteEnd
+    add x9, x9, x17
+    dup s19, v18.s[1]
+    stp s18, s19, [x9]
+    cmp w10, #6
+    beq WriteEnd
+    add x9, x9, x17
+    dup s21, v20.s[1]
+    stp s20, s21, [x9]
+    cmp w10, #7
+    beq WriteEnd
+    add x9, x9, x17
+    dup s23, v22.s[1]
+    stp s22, s23, [x9]
+    cmp w10, #8
+    beq WriteEnd
+    add x9, x9, x17
+    dup s25, v24.s[1]
+    stp s24, s25, [x9]
+    cmp w10, #9
+    beq WriteEnd
+    add x9, x9, x17
+    dup s27, v26.s[1]
+    stp s26, s27, [x9]
+    cmp w10, #10
+    beq WriteEnd
+    add x9, x9, x17
+    dup s29, v28.s[1]
+    stp s28, s29, [x9]
+    cmp w10, #11
+    beq WriteEnd
+    add x9, x9, x17
+    dup s31, v30.s[1]
+    stp s30, s31, [x9]
+    add x9, x9, x17
+    b WriteEnd
+Write3:
+    add x13, x9, #8
+    dup s9, v8.s[1]
+    stp s8, s9, [x9]
+    add x9, x9, x17
+    st1 {v8.s}[2], [x13], x17
+    cmp w10, #1
+    beq WriteEnd
+    dup s11, v10.s[1]
+    stp s10, s11, [x9]
+    add x9, x9, x17
+    st1 {v10.s}[2], [x13], x17
+    cmp w10, #2
+    beq WriteEnd
+    dup s13, v12.s[1]
+    stp s12, s13, [x9]
+    add x9, x9, x17
+    st1 {v12.s}[2], [x13], x17
+    cmp w10, #3
+    beq WriteEnd
+    dup s15, v14.s[1]
+    stp s14, s15, [x9]
+    add x9, x9, x17
+    st1 {v14.s}[2], [x13], x17
+    cmp w10, #4
+    beq WriteEnd
+    dup s17, v16.s[1]
+    stp s16, s17, [x9]
+    add x9, x9, x17
+    st1 {v16.s}[2], [x13], x17
+    cmp w10, #5
+    beq WriteEnd
+    dup s19, v18.s[1]
+    stp s18, s19, [x9]
+    add x9, x9, x17
+    st1 {v18.s}[2], [x13], x17
+    cmp w10, #6
+    beq WriteEnd
+    dup s21, v20.s[1]
+    stp s20, s21, [x9]
+    add x9, x9, x17
+    st1 {v20.s}[2], [x13], x17
+    cmp w10, #7
+    beq WriteEnd
+    dup s23, v22.s[1]
+    stp s22, s23, [x9]
+    add x9, x9, x17
+    st1 {v22.s}[2], [x13], x17
+    cmp w10, #8
+    beq WriteEnd
+    dup s25, v24.s[1]
+    stp s24, s25, [x9]
+    add x9, x9, x17
+    st1 {v24.s}[2], [x13], x17
+    cmp w10, #9
+    beq WriteEnd
+    dup s27, v26.s[1]
+    stp s26, s27, [x9]
+    add x9, x9, x17
+    st1 {v26.s}[2], [x13], x17
+    cmp w10, #10
+    beq WriteEnd
+    dup s29, v28.s[1]
+    stp s28, s29, [x9]
+    add x9, x9, x17
+    st1 {v28.s}[2], [x13], x17
+    cmp w10, #11
+    beq WriteEnd
+    dup s31, v30.s[1]
+    stp s30, s31, [x9]
+    add x9, x9, x17
+    st1 {v30.s}[2], [x13]
+    b WriteEnd
+Write4:
+    st1 {v8.4s}, [x9], x17
+    cmp w10, #1
+    beq WriteEnd
+    st1 {v10.4s}, [x9], x17
+    cmp w10, #2
+    beq WriteEnd
+    st1 {v12.4s}, [x9], x17
+    cmp w10, #3
+    beq WriteEnd
+    st1 {v14.4s}, [x9], x17
+    cmp w10, #4
+    beq WriteEnd
+    st1 {v16.4s}, [x9], x17
+    cmp w10, #5
+    beq WriteEnd
+    st1 {v18.4s}, [x9], x17
+    cmp w10, #6
+    beq WriteEnd
+    st1 {v20.4s}, [x9], x17
+    cmp w10, #7
+    beq WriteEnd
+    st1 {v22.4s}, [x9], x17
+    cmp w10, #8
+    beq WriteEnd
+    st1 {v24.4s}, [x9], x17
+    cmp w10, #9
+    beq WriteEnd
+    st1 {v26.4s}, [x9], x17
+    cmp w10, #10
+    beq WriteEnd
+    st1 {v28.4s}, [x9], x17
+    cmp w10, #11
+    beq WriteEnd
+    st1 {v30.4s}, [x9], x17
+    b WriteEnd
+Write5:
+    add x13, x9, #16
+    st1 {v8.4s}, [x9], x17
+    str s9, [x13]
+    cmp w10, #1
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v10.4s}, [x9], x17
+    str s11, [x13]
+    cmp w10, #2
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v12.4s}, [x9], x17
+    str s13, [x13]
+    cmp w10, #3
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v14.4s}, [x9], x17
+    str s15, [x13]
+    cmp w10, #4
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v16.4s}, [x9], x17
+    str s17, [x13]
+    cmp w10, #5
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v18.4s}, [x9], x17
+    str s19, [x13]
+    cmp w10, #6
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v20.4s}, [x9], x17
+    str s21, [x13]
+    cmp w10, #7
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v22.4s}, [x9], x17
+    str s23, [x13]
+    cmp w10, #8
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v24.4s}, [x9], x17
+    str s25, [x13]
+    cmp w10, #9
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v26.4s}, [x9], x17
+    str s27, [x13]
+    cmp w10, #10
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v28.4s}, [x9], x17
+    str s29, [x13]
+    cmp w10, #11
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v30.4s}, [x9], x17
+    str s31, [x13]
+    b WriteEnd
+Write6:
+    add x13, x9, #16
+    st1 {v8.4s}, [x9], x17
+    dup s8, v9.s[1]
+    stp s9, s8, [x13]
+    cmp w10, #1
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v10.4s}, [x9], x17
+    dup s10, v11.s[1]
+    stp s11, s10, [x13]
+    cmp w10, #2
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v12.4s}, [x9], x17
+    dup s12, v13.s[1]
+    stp s13, s12, [x13]
+    cmp w10, #3
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v14.4s}, [x9], x17
+    dup s14, v15.s[1]
+    stp s15, s14, [x13]
+    cmp w10, #4
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v16.4s}, [x9], x17
+    dup s16, v17.s[1]
+    stp s17, s16, [x13]
+    cmp w10, #5
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v18.4s}, [x9], x17
+    dup s18, v19.s[1]
+    stp s19, s18, [x13]
+    cmp w10, #6
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v20.4s}, [x9], x17
+    dup s20, v21.s[1]
+    stp s21, s20, [x13]
+    cmp w10, #7
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v22.4s}, [x9], x17
+    dup s22, v23.s[1]
+    stp s23, s22, [x13]
+    cmp w10, #8
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v24.4s}, [x9], x17
+    dup s24, v25.s[1]
+    stp s25, s24, [x13]
+    cmp w10, #9
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v26.4s}, [x9], x17
+    dup s26, v27.s[1]
+    stp s27, s26, [x13]
+    cmp w10, #10
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v28.4s}, [x9], x17
+    dup s28, v29.s[1]
+    stp s29, s28, [x13]
+    cmp w10, #11
+    beq WriteEnd
+    add x13, x13, x17
+    st1 {v30.4s}, [x9], x17
+    dup s30, v31.s[1]
+    stp s31, s30, [x13]
+    b WriteEnd
+Write7:
+    add x13, x9, #16
+    add x16, x9, #24 // x16 is free here: the rhs pointer is reloaded at L2
+    st1 {v8.4s}, [x9], x17
+    dup s8, v9.s[1]
+    stp s9, s8, [x13]
+    add x13, x13, x17
+    st1 {v9.s}[2], [x16], x17
+    cmp w10, #1
+    beq WriteEnd
+    st1 {v10.4s}, [x9], x17
+    dup s10, v11.s[1]
+    stp s11, s10, [x13]
+    add x13, x13, x17
+    st1 {v11.s}[2], [x16], x17
+    cmp w10, #2
+    beq WriteEnd
+    st1 {v12.4s}, [x9], x17
+    dup s12, v13.s[1]
+    stp s13, s12, [x13]
+    add x13, x13, x17
+    st1 {v13.s}[2], [x16], x17
+    cmp w10, #3
+    beq WriteEnd
+    st1 {v14.4s}, [x9], x17
+    dup s14, v15.s[1]
+    stp s15, s14, [x13]
+    add x13, x13, x17
+    st1 {v15.s}[2], [x16], x17
+    cmp w10, #4
+    beq WriteEnd
+    st1 {v16.4s}, [x9], x17
+    dup s16, v17.s[1]
+    stp s17, s16, [x13]
+    add x13, x13, x17
+    st1 {v17.s}[2], [x16], x17
+    cmp w10, #5
+    beq WriteEnd
+    st1 {v18.4s}, [x9], x17
+    dup s18, v19.s[1]
+    stp s19, s18, [x13]
+    add x13, x13, x17
+    st1 {v19.s}[2], [x16], x17
+    cmp w10, #6
+    beq WriteEnd
+    st1 {v20.4s}, [x9], x17
+    dup s20, v21.s[1]
+    stp s21, s20, [x13]
+    add x13, x13, x17
+    st1 {v21.s}[2], [x16], x17
+    cmp w10, #7
+    beq WriteEnd
+    st1 {v22.4s}, [x9], x17
+    dup s22, v23.s[1]
+    stp s23, s22, [x13]
+    add x13, x13, x17
+    st1 {v23.s}[2], [x16], x17
+    cmp w10, #8
+    beq WriteEnd
+    st1 {v24.4s}, [x9], x17
+    dup s24, v25.s[1]
+    stp s25, s24, [x13]
+    add x13, x13, x17
+    st1 {v25.s}[2], [x16], x17
+    cmp w10, #9
+    beq WriteEnd
+    st1 {v26.4s}, [x9], x17
+    dup s26, v27.s[1]
+    stp s27, s26, [x13]
+    add x13, x13, x17
+    st1 {v27.s}[2], [x16], x17
+    cmp w10, #10
+    beq WriteEnd
+    st1 {v28.4s}, [x9], x17
+    dup s28, v29.s[1]
+    stp s29, s28, [x13]
+    add x13, x13, x17
+    st1 {v29.s}[2], [x16], x17
+    cmp w10, #11
+    beq WriteEnd
+    st1 {v30.4s}, [x9], x17
+    dup s30, v31.s[1]
+    stp s31, s30, [x13]
+    add x13, x13, x17
+    st1 {v31.s}[2], [x16], x17
+    b WriteEnd
+WriteC8:
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 // whole 12x8 tile, x2 itself advances
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
+    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64
+    st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
+    st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    b WriteEnd
+Write8:
+    st1 {v8.4s, v9.4s}, [x9], x17
+    cmp w10, #1
+    beq WriteEnd
+    st1 {v10.4s, v11.4s}, [x9], x17
+    cmp w10, #2
+    beq WriteEnd
+    st1 {v12.4s, v13.4s}, [x9], x17
+    cmp w10, #3
+    beq WriteEnd
+    st1 {v14.4s, v15.4s}, [x9], x17
+    cmp w10, #4
+    beq WriteEnd
+    st1 {v16.4s, v17.4s}, [x9], x17
+    cmp w10, #5
+    beq WriteEnd
+    st1 {v18.4s, v19.4s}, [x9], x17
+    cmp w10, #6
+    beq WriteEnd
+    st1 {v20.4s, v21.4s}, [x9], x17
+    cmp w10, #7
+    beq WriteEnd
+    st1 {v22.4s, v23.4s}, [x9], x17
+    cmp w10, #8
+    beq WriteEnd
+    st1 {v24.4s, v25.4s}, [x9], x17
+    cmp w10, #9
+    beq WriteEnd
+    st1 {v26.4s, v27.4s}, [x9], x17
+    cmp w10, #10
+    beq WriteEnd
+    st1 {v28.4s, v29.4s}, [x9], x17
+    cmp w10, #11
+    beq WriteEnd
+    st1 {v30.4s, v31.4s}, [x9], x17
+
+WriteEnd:
+    subs w10, w10, #12 // lhs row - 12
+    bgt L2
+
+End2:
+    subs w7, w7, #8 // rhs col - 8
+    add x1, x1, x15 // rhs ptr + stride
+    add x3, x3, #32 // bias ptr + stride
+    ldrb w13, [sp, #136] // write_nhwc
+    cbz w13, NoDstStep // WriteC8 already advanced x2 through its post-increments
+    add x2, x2, #32 // dst ptr + stride
+NoDstStep:
+    bgt L1 // still tests the subs above: add/ldrb/cbz do not touch the flags
+
+End1:
+    // sp still points at the save area; the two post-increments below release it
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ret
+#endif
diff --git a/mindspore/lite/nnacl/fp32/matmul.c b/mindspore/lite/nnacl/fp32/matmul.c
index
03c65ef23f5ddfd12286b06315e485c375f290e7..97ff005637a7add340e1a6095696d27b82e265c5 100644
--- a/mindspore/lite/nnacl/fp32/matmul.c
+++ b/mindspore/lite/nnacl/fp32/matmul.c
@@ -28,6 +28,114 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
   return;
 }
 
+/* Packs a row-major [row x col] matrix into 12-row panels: inside a panel the
+ * 12 values of one column are contiguous (dst[panel][c][r12]).
+ * Only `row` source rows are written; when row is not a multiple of C12NUM the
+ * unused slots of the last panel keep whatever dst_ptr already held, so the
+ * caller has to pre-clear a buffer of UP_ROUND(row, C12NUM) * col floats
+ * (Convolution1x1CPUKernel::InitConv1x1Param memsets pack_input_ for this). */
+void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+  size_t row12 = row / C12NUM * C12NUM;
+  size_t col4 = col / C4NUM * C4NUM;
+  float *src_r = src_ptr;
+  float *dst_r = dst_ptr;
+
+  size_t ri = 0;
+  for (; ri < row12; ri += C12NUM) {
+    size_t ci = 0;
+    for (; ci < col4; ci += C4NUM) {
+      float *src_c = src_r + ci;
+      float *dst_c = dst_r + ci * C12NUM;
+
+      /* 12x4 row-major to col-major */
+#ifdef ENABLE_ARM64
+      size_t stride = col * sizeof(float);
+      asm volatile(
+        "mov x10, %[src_c]\n"
+        "mov x11, %[dst_c]\n"
+
+        "ld1 {v0.4s}, [x10], %[stride]\n"
+        "ld1 {v1.4s}, [x10], %[stride]\n"
+        "ld1 {v2.4s}, [x10], %[stride]\n"
+        "ld1 {v3.4s}, [x10], %[stride]\n"
+
+        "ld1 {v4.4s}, [x10], %[stride]\n"
+        "ld1 {v5.4s}, [x10], %[stride]\n"
+        "ld1 {v6.4s}, [x10], %[stride]\n"
+        "ld1 {v7.4s}, [x10], %[stride]\n"
+
+        "zip1 v12.4s, v0.4s, v1.4s\n"
+        "zip2 v13.4s, v0.4s, v1.4s\n"
+        "zip1 v14.4s, v2.4s, v3.4s\n"
+        "zip2 v15.4s, v2.4s, v3.4s\n"
+
+        "ld1 {v8.4s}, [x10], %[stride]\n"
+        "ld1 {v9.4s}, [x10], %[stride]\n"
+        "ld1 {v10.4s}, [x10], %[stride]\n"
+        "ld1 {v11.4s}, [x10], %[stride]\n"
+
+        "zip1 v16.4s, v4.4s, v5.4s\n"
+        "zip2 v17.4s, v4.4s, v5.4s\n"
+        "zip1 v18.4s, v6.4s, v7.4s\n"
+        "zip2 v19.4s, v6.4s, v7.4s\n"
+
+        "trn1 v20.2d, v12.2d, v14.2d\n"
+        "trn2 v23.2d, v12.2d, v14.2d\n"
+        "trn1 v26.2d, v13.2d, v15.2d\n"
+        "trn2 v29.2d, v13.2d, v15.2d\n"
+
+        "trn1 v21.2d, v16.2d, v18.2d\n"
+        "trn2 v24.2d, v16.2d, v18.2d\n"
+        "trn1 v27.2d, v17.2d, v19.2d\n"
+        "trn2 v30.2d, v17.2d, v19.2d\n"
+
+        "zip1 v12.4s, v8.4s, v9.4s\n"
+        "zip2 v13.4s, v8.4s, v9.4s\n"
+        "zip1 v14.4s, v10.4s, v11.4s\n"
+        "zip2 v15.4s, v10.4s, v11.4s\n"
+
+        "trn1 v22.2d, v12.2d, v14.2d\n"
+        "trn2 v25.2d, v12.2d, v14.2d\n"
+        "trn1 v28.2d, v13.2d, v15.2d\n"
+        "trn2 v31.2d, v13.2d, v15.2d\n"
+
+        "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
+        "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
+        "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
+
+        :
+        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
+        : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
+          "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+          "v30", "v31", "memory"); /* "memory": the asm reads src_c and writes dst_c behind the compiler's back */
+#else
+      for (int tr = 0; tr < C12NUM; tr++) {
+        for (int tc = 0; tc < C4NUM; tc++) {
+          dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
+        }
+      }
+#endif
+    }
+    for (; ci < col; ci++) {
+      float *src_c = src_r + ci;
+      float *dst_c = dst_r + ci * C12NUM;
+      for (size_t i = 0; i < C12NUM; i++) {
+        dst_c[i] = src_c[i * col];
+      }
+    }
+    src_r += C12NUM * col;
+    dst_r += C12NUM * col;
+  }
+  for (; ri < row; ri++) {
+    for (size_t i = 0; i < col; i++) {
+      dst_r[i * C12NUM] = src_r[i];
+    }
+    src_r += col;
+    dst_r += 1;
+  }
+  return;
+}
+
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
   size_t row8 = row / C8NUM * C8NUM;
   size_t col4 = col / C4NUM * C4NUM;
@@ -267,6 +375,36 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
   return;
 }
 
+/* Fallback used by MatMulOpt when ENABLE_ARM64 is not defined.
+ * a: col12-major panels (see RowMajor2Col12Major), b: row8-major panels,
+ * dst: row-major, `stride` floats per output row, bias indexed by column.
+ * Only the write_nhwc layout is implemented: with write_nhwc == false the
+ * function returns without writing dst. */
+void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, bool write_nhwc) {
+  if (write_nhwc) {
+    /* col12-major * row8-major => row-major */
+    for (int r = 0; r < row; r++) {
+      for (int c = 0; c < col; c++) {
+        int r12div = r / 12, r12mod = r % 12;
+        int c8div = c / 8, c8mod = c % 8;
+        size_t ci = r * stride + c;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r12div * deep * 12 + d * 12 + r12mod;
+          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  }
+  return;
+}
+
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
             int col, int stride, bool write_nhwc) {
 #ifdef ENABLE_ARM64
@@ -275,3 +413,15 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
   MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
 #endif
 }
+
+/* 12x8-tiled matmul entry point: `a` must be packed with RowMajor2Col12Major.
+ * Dispatches to the NEON kernel on ARM64 builds and to MatMul12x8 elsewhere;
+ * note that only the NEON kernel handles write_nhwc == false. */
+void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
+               int col, int stride, bool write_nhwc) {
+#ifdef ENABLE_ARM64
+  MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
+#else
+  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
+#endif
+}
diff --git a/mindspore/lite/nnacl/fp32/matmul.h b/mindspore/lite/nnacl/fp32/matmul.h
index 7459e426ea288f5c410b629c854c738ec00cc750..4bc74c1b3a3e6db556c78d474dfc790b229d0c73 100644
--- a/mindspore/lite/nnacl/fp32/matmul.h
+++ b/mindspore/lite/nnacl/fp32/matmul.h
@@ -28,12 +28,17 @@ extern "C" {
 #endif
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
             int col, int stride, bool write_nhwc);
+void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
+               int col, int stride, bool write_nhwc);
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
+void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
 #ifdef ENABLE_ARM64
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                        int col, size_t stride, bool write_nhwc);
+void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                          int col, size_t stride, bool write_nhwc);
 #endif
 #ifdef __cplusplus
 }
diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h
index 9e290e78410124acb2ccad70155521bd86e467a3..fa377468e5e6b2b78eb813b40eb47eecb098b6cb 100644
--- a/mindspore/lite/nnacl/matmul_parameter.h
+++ b/mindspore/lite/nnacl/matmul_parameter.h
@@ -31,6 +31,7 @@ typedef struct MatMulParameter {
   int row_;
   int col_;
   int row_8_;
+  int row_12_;
   int row_16_;
   int col_8_;
   int deep_;
diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h
index f7c90bce49d03fb9fb033b84bf6f11771b2c57f7..11450bc48931d116da23fda0b6396d6dde8f2a91 100644
--- a/mindspore/lite/nnacl/op_base.h
+++ b/mindspore/lite/nnacl/op_base.h
@@ -23,6 +23,7 @@
 
 #define C4NUM 4
 #define C8NUM 8
+#define C12NUM 12
 #define C16NUM 16
 #define BLOCK 4
 #define TILE_NUM 8
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
index 30b2b6a1e3fb787e8d4347e31589f7dd5a74833b..798488a1bd58e3f7b5e83967bd4549d9ad02b7ff 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
@@ -59,7 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
   matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
   matmul_param_->col_ = conv_param_->output_channel_;
   matmul_param_->deep_ = conv_param_->input_channel_;
-  matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
+  matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
   matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
   matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
   matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
@@ -100,12 +100,12 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
   thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
   thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
 
-  pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
+  pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
   if (pack_input_ == nullptr) {
     MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
     return RET_MEMORY_FAILED;
   }
-  memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));
+  memset(pack_input_, 0, matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float));
   return RET_OK;
 }
 
@@ -118,7 +118,7 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
     input_ptr_ = src_input;
   }
 
-  RowMajor2Col8Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+  RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
   return;
 }
 
@@ -143,7 +143,7 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
 
   auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
 
-  MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
+  MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
          output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
          matmul_param_->row_, cur_oc, matmul_param_->col_, true);