提交 0f4e6d73 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5004 [MS][LITE][Develop]optimize fp32 matmul kernel

Merge pull request !5004 from lixian/master
#ifdef __aarch64__
.text
.align 5
.global MatmulFloatNeon64Opt
#ifndef __APPLE__
.type MatmulFloatNeon64Opt, %function
#endif
// A: LM [row_8 * depth] col_8_major
// B: RM [depth * col_8] row_8_major
// C: A*B [row_8 * col_8] col_8x8_major
// A * B -> [8 * depth] * [depth * 8] -> [8 * 4] * [4 * 8] or [8 * 1] * [1 * 8]
///////////////////////////////////////////////////////////////////////////////
//CommLoopMul RM 1x8 block
// /-----------------------------------------\
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
// \-----------------------------------------/
// LM 8x1 block
// /---------------------\ /-----------------------------------------\
// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]|
// | ... | | ... ... |
// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]|
// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]|
// | ... | | ... ... |
// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]|
// \---------------------/ \-----------------------------------------/
// accumulators 8x8 block
//
///////////////////////////////////////////////////////////////////////////////
//OptLoopMul4 RM 4x8 block
// /--------------------------------------------\
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
// |v12.s[0] ... v12.s[3] v13.s[0] ... v13.s[3]|
// |v14.s[0] ... v14.s[3] v15.s[0] ... v15.s[3]|
// \--------------------------------------------/
// LM 8x4 block
// /---------------------------------\ /--------------------------------------------\
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] |
// | ... ... ... ... | | ... ... |
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] |
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] |
// | ... ... ... ... | | ... ... |
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] |
// \---------------------------------/ \--------------------------------------------/
// accumulators 8x8 block
/////////////////////////////////////////////////////////////////////////////////
//
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
// int row, int col, int stride, bool write_nhwc)
// x0: a
// x1: b
// x2: c
// x3: bias
// w4: act_type
// w5: depth
// w6: row
// w7: col
// w17: stride
// w13: writeC8
MatmulFloatNeon64Opt:
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
mov w18, #32 // sizeof(float) * 8
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x11, x3 // bias flag
mov x18, #4
ldr x17, [sp]
mul x17, x17, x18
L1:
mov w10, w6 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
L2:
mov x16, x1 // reload rhs ptr
mov w13, w5 // reload depth
mov x14, x3 // reload bias ptr
dup v8.4s, wzr
dup v9.4s, wzr
dup v10.4s, wzr
dup v11.4s, wzr
dup v12.4s, wzr
dup v13.4s, wzr
dup v14.4s, wzr
dup v15.4s, wzr
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
LoopStart:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v8.4s, v3.4s, v0.s[0]
fmla v10.4s, v3.4s, v0.s[1]
fmla v12.4s, v3.4s, v0.s[2]
fmla v14.4s, v3.4s, v0.s[3]
fmla v9.4s, v4.4s, v0.s[0]
fmla v11.4s, v4.4s, v0.s[1]
fmla v13.4s, v4.4s, v0.s[2]
fmla v15.4s, v4.4s, v0.s[3]
subs w13, w13, #1
beq LoopEnd
Loop:
ld1 {v0.4s}, [x12], #16
fmla v16.4s, v3.4s, v1.s[0]
fmla v18.4s, v3.4s, v1.s[1]
fmla v20.4s, v3.4s, v1.s[2]
fmla v22.4s, v3.4s, v1.s[3]
fmla v17.4s, v4.4s, v1.s[0]
fmla v19.4s, v4.4s, v1.s[1]
fmla v21.4s, v4.4s, v1.s[2]
fmla v23.4s, v4.4s, v1.s[3]
ld1 {v1.4s}, [x12], #16
fmla v24.4s, v3.4s, v2.s[0]
fmla v26.4s, v3.4s, v2.s[1]
fmla v28.4s, v3.4s, v2.s[2]
fmla v30.4s, v3.4s, v2.s[3]
ld1 {v3.4s}, [x16], #16
fmla v25.4s, v4.4s, v2.s[0]
fmla v27.4s, v4.4s, v2.s[1]
fmla v29.4s, v4.4s, v2.s[2]
fmla v31.4s, v4.4s, v2.s[3]
ld1 {v4.4s}, [x16], #16
fmla v8.4s, v3.4s, v0.s[0]
fmla v10.4s, v3.4s, v0.s[1]
fmla v12.4s, v3.4s, v0.s[2]
fmla v14.4s, v3.4s, v0.s[3]
ld1 {v2.4s}, [x12], #16
fmla v9.4s, v4.4s, v0.s[0]
fmla v11.4s, v4.4s, v0.s[1]
fmla v13.4s, v4.4s, v0.s[2]
fmla v15.4s, v4.4s, v0.s[3]
subs w13, w13, #1
bgt Loop
LoopEnd:
fmla v16.4s, v3.4s, v1.s[0]
fmla v18.4s, v3.4s, v1.s[1]
fmla v20.4s, v3.4s, v1.s[2]
fmla v22.4s, v3.4s, v1.s[3]
fmla v17.4s, v4.4s, v1.s[0]
fmla v19.4s, v4.4s, v1.s[1]
fmla v21.4s, v4.4s, v1.s[2]
fmla v23.4s, v4.4s, v1.s[3]
fmla v24.4s, v3.4s, v2.s[0]
fmla v26.4s, v3.4s, v2.s[1]
fmla v28.4s, v3.4s, v2.s[2]
fmla v30.4s, v3.4s, v2.s[3]
fmla v25.4s, v4.4s, v2.s[0]
fmla v27.4s, v4.4s, v2.s[1]
fmla v29.4s, v4.4s, v2.s[2]
fmla v31.4s, v4.4s, v2.s[3]
Bias:
cbz x11, Activation
ld1 {v0.4s}, [x14], #16
ld1 {v1.4s}, [x14], #16
fadd v8.4s, v8.4s, v0.4s
fadd v9.4s, v9.4s, v1.4s
fadd v10.4s, v10.4s, v0.4s
fadd v11.4s, v11.4s, v1.4s
fadd v12.4s, v12.4s, v0.4s
fadd v13.4s, v13.4s, v1.4s
fadd v14.4s, v14.4s, v0.4s
fadd v15.4s, v15.4s, v1.4s
fadd v16.4s, v16.4s, v0.4s
fadd v17.4s, v17.4s, v1.4s
fadd v18.4s, v18.4s, v0.4s
fadd v19.4s, v19.4s, v1.4s
fadd v20.4s, v20.4s, v0.4s
fadd v21.4s, v21.4s, v1.4s
fadd v22.4s, v22.4s, v0.4s
fadd v23.4s, v23.4s, v1.4s
fadd v24.4s, v24.4s, v0.4s
fadd v25.4s, v25.4s, v1.4s
fadd v26.4s, v26.4s, v0.4s
fadd v27.4s, v27.4s, v1.4s
fadd v28.4s, v28.4s, v0.4s
fadd v29.4s, v29.4s, v1.4s
fadd v30.4s, v30.4s, v0.4s
fadd v31.4s, v31.4s, v1.4s
Activation:
cmp w4, #2
beq Relu6
cmp w4, #1
beq Relu
b Write
Relu6:
mov w8, #6
dup v2.4s, w8
scvtf v2.4s, v2.4s
fmin v8.4s, v8.4s, v2.4s
fmin v9.4s, v9.4s, v2.4s
fmin v10.4s, v10.4s, v2.4s
fmin v11.4s, v11.4s, v2.4s
fmin v12.4s, v12.4s, v2.4s
fmin v13.4s, v13.4s, v2.4s
fmin v14.4s, v14.4s, v2.4s
fmin v15.4s, v15.4s, v2.4s
fmin v16.4s, v16.4s, v2.4s
fmin v17.4s, v17.4s, v2.4s
fmin v18.4s, v18.4s, v2.4s
fmin v19.4s, v19.4s, v2.4s
fmin v20.4s, v20.4s, v2.4s
fmin v21.4s, v21.4s, v2.4s
fmin v22.4s, v22.4s, v2.4s
fmin v23.4s, v23.4s, v2.4s
fmin v24.4s, v24.4s, v2.4s
fmin v25.4s, v25.4s, v2.4s
fmin v26.4s, v26.4s, v2.4s
fmin v27.4s, v27.4s, v2.4s
fmin v28.4s, v28.4s, v2.4s
fmin v29.4s, v29.4s, v2.4s
fmin v30.4s, v30.4s, v2.4s
fmin v31.4s, v31.4s, v2.4s
Relu:
dup v3.4s, wzr
fmax v8.4s, v8.4s, v3.4s
fmax v9.4s, v9.4s, v3.4s
fmax v10.4s, v10.4s, v3.4s
fmax v11.4s, v11.4s, v3.4s
fmax v12.4s, v12.4s, v3.4s
fmax v13.4s, v13.4s, v3.4s
fmax v14.4s, v14.4s, v3.4s
fmax v15.4s, v15.4s, v3.4s
fmax v16.4s, v16.4s, v3.4s
fmax v17.4s, v17.4s, v3.4s
fmax v18.4s, v18.4s, v3.4s
fmax v19.4s, v19.4s, v3.4s
fmax v20.4s, v20.4s, v3.4s
fmax v21.4s, v21.4s, v3.4s
fmax v22.4s, v22.4s, v3.4s
fmax v23.4s, v23.4s, v3.4s
fmax v24.4s, v24.4s, v3.4s
fmax v25.4s, v25.4s, v3.4s
fmax v26.4s, v26.4s, v3.4s
fmax v27.4s, v27.4s, v3.4s
fmax v28.4s, v28.4s, v3.4s
fmax v29.4s, v29.4s, v3.4s
fmax v30.4s, v30.4s, v3.4s
fmax v31.4s, v31.4s, v3.4s
Write:
ldrb w13, [sp, #8]
cbz w13, WriteC8
cmp w7, #1
beq Write1
cmp w7, #2
beq Write2
cmp w7, #3
beq Write3
cmp w7, #4
beq Write4
cmp w7, #5
beq Write5
cmp w7, #6
beq Write6
cmp w7, #7
beq Write7
b Write8
Write1:
str s8, [x18]
cmp w10, #1
beq WriteEnd
add x18, x18, x17
str s10, [x18]
cmp w10, #2
beq WriteEnd
add x18, x18, x17
str s12, [x18]
cmp w10, #3
beq WriteEnd
add x18, x18, x17
str s14, [x18]
cmp w10, #4
beq WriteEnd
add x18, x18, x17
str s16, [x18]
cmp w10, #5
beq WriteEnd
add x18, x18, x17
str s18, [x18]
cmp w10, #6
beq WriteEnd
add x18, x18, x17
str s20, [x18]
cmp w10, #7
beq WriteEnd
add x18, x18, x17
str s22, [x18]
cmp w10, #8
beq WriteEnd
add x18, x18, x17
str s24, [x18]
cmp w10, #9
beq WriteEnd
add x18, x18, x17
str s26, [x18]
cmp w10, #10
beq WriteEnd
add x18, x18, x17
str s28, [x18]
cmp w10, #11
beq WriteEnd
add x18, x18, x17
str s30, [x18]
add x18, x18, x17
b WriteEnd
Write2:
dup s9, v8.s[1]
stp s8, s9, [x18]
cmp w10, #1
beq WriteEnd
add x18, x18, x17
dup s11, v10.s[1]
stp s10, s11, [x18]
cmp w10, #2
beq WriteEnd
add x18, x18, x17
dup s13, v12.s[1]
stp s12, s13, [x18]
cmp w10, #3
beq WriteEnd
add x18, x18, x17
dup s15, v14.s[1]
stp s14, s15, [x18]
cmp w10, #4
beq WriteEnd
add x18, x18, x17
dup s17, v16.s[1]
stp s16, s17, [x18]
cmp w10, #5
beq WriteEnd
add x18, x18, x17
dup s19, v18.s[1]
stp s18, s19, [x18]
cmp w10, #6
beq WriteEnd
add x18, x18, x17
dup s21, v20.s[1]
stp s20, s21, [x18]
cmp w10, #7
beq WriteEnd
add x18, x18, x17
dup s23, v22.s[1]
stp s22, s23, [x18]
cmp w10, #8
beq WriteEnd
add x18, x18, x17
dup s25, v24.s[1]
stp s24, s25, [x18]
cmp w10, #9
beq WriteEnd
add x18, x18, x17
dup s27, v26.s[1]
stp s26, s27, [x18]
cmp w10, #10
beq WriteEnd
add x18, x18, x17
dup s29, v28.s[1]
stp s28, s29, [x18]
cmp w10, #11
beq WriteEnd
add x18, x18, x17
dup s31, v30.s[1]
stp s30, s31, [x18]
add x18, x18, x17
b WriteEnd
Write3:
add x13, x18, #8
dup s9, v8.s[1]
stp s8, s9, [x18]
add x18, x18, x17
st1 {v8.s}[2], [x13], x17
cmp w10, #1
beq WriteEnd
dup s11, v10.s[1]
stp s10, s11, [x18]
add x18, x18, x17
st1 {v10.s}[2], [x13], x17
cmp w10, #2
beq WriteEnd
dup s13, v12.s[1]
stp s12, s13, [x18]
add x18, x18, x17
st1 {v12.s}[2], [x13], x17
cmp w10, #3
beq WriteEnd
dup s15, v14.s[1]
stp s14, s15, [x18]
add x18, x18, x17
st1 {v14.s}[2], [x13], x17
cmp w10, #4
beq WriteEnd
dup s17, v16.s[1]
stp s16, s17, [x18]
add x18, x18, x17
st1 {v16.s}[2], [x13], x17
cmp w10, #5
beq WriteEnd
dup s19, v18.s[1]
stp s18, s19, [x18]
add x18, x18, x17
st1 {v18.s}[2], [x13], x17
cmp w10, #6
beq WriteEnd
dup s21, v20.s[1]
stp s20, s21, [x18]
add x18, x18, x17
st1 {v20.s}[2], [x13], x17
cmp w10, #7
beq WriteEnd
dup s23, v22.s[1]
stp s22, s23, [x18]
add x18, x18, x17
st1 {v22.s}[2], [x13], x17
cmp w10, #8
beq WriteEnd
dup s25, v24.s[1]
stp s24, s25, [x18]
add x18, x18, x17
st1 {v24.s}[2], [x13], x17
cmp w10, #9
beq WriteEnd
dup s27, v26.s[1]
stp s26, s27, [x18]
add x18, x18, x17
st1 {v26.s}[2], [x13], x17
cmp w10, #10
beq WriteEnd
dup s29, v28.s[1]
stp s28, s29, [x18]
add x18, x18, x17
st1 {v28.s}[2], [x13], x17
cmp w10, #11
beq WriteEnd
dup s31, v30.s[1]
stp s30, s31, [x18]
add x18, x18, x17
st1 {v30.s}[2], [x13]
b WriteEnd
Write4:
st1 {v8.4s}, [x18], x17
cmp w10, #1
beq WriteEnd
st1 {v10.4s}, [x18], x17
cmp w10, #2
beq WriteEnd
st1 {v12.4s}, [x18], x17
cmp w10, #3
beq WriteEnd
st1 {v14.4s}, [x18], x17
cmp w10, #4
beq WriteEnd
st1 {v16.4s}, [x18], x17
cmp w10, #5
beq WriteEnd
st1 {v18.4s}, [x18], x17
cmp w10, #6
beq WriteEnd
st1 {v20.4s}, [x18], x17
cmp w10, #7
beq WriteEnd
st1 {v22.4s}, [x18], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4s}, [x18], x17
cmp w10, #9
beq WriteEnd
st1 {v26.4s}, [x18], x17
cmp w10, #10
beq WriteEnd
st1 {v28.4s}, [x18], x17
cmp w10, #11
beq WriteEnd
st1 {v30.4s}, [x18], x17
b WriteEnd
Write5:
add x13, x18, #16
st1 {v8.4s}, [x18], x17
str s9, [x13]
cmp w10, #1
beq WriteEnd
add x13, x13, x17
st1 {v10.4s}, [x18], x17
str s11, [x13]
cmp w10, #2
beq WriteEnd
add x13, x13, x17
st1 {v12.4s}, [x18], x17
str s13, [x13]
cmp w10, #3
beq WriteEnd
add x13, x13, x17
st1 {v14.4s}, [x18], x17
str s15, [x13]
cmp w10, #4
beq WriteEnd
add x13, x13, x17
st1 {v16.4s}, [x18], x17
str s17, [x13]
cmp w10, #5
beq WriteEnd
add x13, x13, x17
st1 {v18.4s}, [x18], x17
str s19, [x13]
cmp w10, #6
beq WriteEnd
add x13, x13, x17
st1 {v20.4s}, [x18], x17
str s21, [x13]
cmp w10, #7
beq WriteEnd
add x13, x13, x17
st1 {v22.4s}, [x18], x17
str s23, [x13]
cmp w10, #8
beq WriteEnd
add x13, x13, x17
st1 {v24.4s}, [x18], x17
str s25, [x13]
cmp w10, #9
beq WriteEnd
add x13, x13, x17
st1 {v26.4s}, [x18], x17
str s27, [x13]
cmp w10, #10
beq WriteEnd
add x13, x13, x17
st1 {v28.4s}, [x18], x17
str s29, [x13]
cmp w10, #11
beq WriteEnd
add x13, x13, x17
st1 {v30.4s}, [x18], x17
str s31, [x13]
b WriteEnd
Write6:
add x13, x18, #16
st1 {v8.4s}, [x18], x17
dup s8, v9.s[1]
stp s9, s8, [x13]
cmp w10, #1
beq WriteEnd
add x13, x13, x17
st1 {v10.4s}, [x18], x17
dup s10, v11.s[1]
stp s11, s10, [x13]
cmp w10, #2
beq WriteEnd
add x13, x13, x17
st1 {v12.4s}, [x18], x17
dup s12, v13.s[1]
stp s13, s12, [x13]
cmp w10, #3
beq WriteEnd
add x13, x13, x17
st1 {v14.4s}, [x18], x17
dup s14, v15.s[1]
stp s15, s14, [x13]
cmp w10, #4
beq WriteEnd
add x13, x13, x17
st1 {v16.4s}, [x18], x17
dup s16, v17.s[1]
stp s17, s16, [x13]
cmp w10, #5
beq WriteEnd
add x13, x13, x17
st1 {v18.4s}, [x18], x17
dup s18, v19.s[1]
stp s19, s18, [x13]
cmp w10, #6
beq WriteEnd
add x13, x13, x17
st1 {v20.4s}, [x18], x17
dup s20, v21.s[1]
stp s21, s20, [x13]
cmp w10, #7
beq WriteEnd
add x13, x13, x17
st1 {v22.4s}, [x18], x17
dup s22, v23.s[1]
stp s23, s22, [x13]
cmp w10, #8
beq WriteEnd
add x13, x13, x17
st1 {v24.4s}, [x18], x17
dup s24, v25.s[1]
stp s25, s24, [x13]
cmp w10, #9
beq WriteEnd
add x13, x13, x17
st1 {v26.4s}, [x18], x17
dup s26, v27.s[1]
stp s27, s26, [x13]
cmp w10, #10
beq WriteEnd
add x13, x13, x17
st1 {v28.4s}, [x18], x17
dup s28, v29.s[1]
stp s29, s28, [x13]
cmp w10, #11
beq WriteEnd
add x13, x13, x17
st1 {v30.4s}, [x18], x17
dup s30, v31.s[1]
stp s31, s30, [x13]
b WriteEnd
Write7:
add x13, x18, #16
add x16, x18, #24
st1 {v8.4s}, [x18], x17
dup s8, v9.s[1]
stp s9, s8, [x13]
add x13, x13, x17
st1 {v9.s}[2], [x16], x17
cmp w10, #1
beq WriteEnd
st1 {v10.4s}, [x18], x17
dup s10, v11.s[1]
stp s11, s10, [x13]
add x13, x13, x17
st1 {v11.s}[2], [x16], x17
cmp w10, #2
beq WriteEnd
st1 {v12.4s}, [x18], x17
dup s12, v13.s[1]
stp s13, s12, [x13]
add x13, x13, x17
st1 {v13.s}[2], [x16], x17
cmp w10, #3
beq WriteEnd
st1 {v14.4s}, [x18], x17
dup s14, v15.s[1]
stp s15, s14, [x13]
add x13, x13, x17
st1 {v15.s}[2], [x16], x17
cmp w10, #4
beq WriteEnd
st1 {v16.4s}, [x18], x17
dup s16, v17.s[1]
stp s17, s16, [x13]
add x13, x13, x17
st1 {v17.s}[2], [x16], x17
cmp w10, #5
beq WriteEnd
st1 {v18.4s}, [x18], x17
dup s18, v19.s[1]
stp s19, s18, [x13]
add x13, x13, x17
st1 {v19.s}[2], [x16], x17
cmp w10, #6
beq WriteEnd
st1 {v20.4s}, [x18], x17
dup s20, v21.s[1]
stp s21, s20, [x13]
add x13, x13, x17
st1 {v21.s}[2], [x16], x17
cmp w10, #7
beq WriteEnd
st1 {v22.4s}, [x18], x17
dup s22, v23.s[1]
stp s23, s22, [x13]
add x13, x13, x17
st1 {v23.s}[2], [x16], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4s}, [x18], x17
dup s24, v25.s[1]
stp s25, s24, [x13]
add x13, x13, x17
st1 {v25.s}[2], [x16], x17
cmp w10, #9
beq WriteEnd
st1 {v26.4s}, [x18], x17
dup s26, v27.s[1]
stp s27, s26, [x13]
add x13, x13, x17
st1 {v27.s}[2], [x16], x17
cmp w10, #10
beq WriteEnd
st1 {v28.4s}, [x18], x17
dup s28, v29.s[1]
stp s29, s28, [x13]
add x13, x13, x17
st1 {v29.s}[2], [x16], x17
cmp w10, #11
beq WriteEnd
st1 {v30.4s}, [x18], x17
dup s30, v31.s[1]
stp s31, s30, [x13]
add x13, x13, x17
st1 {v31.s}[2], [x16], x17
b WriteEnd
WriteC8:
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x2], #64
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
b WriteEnd
Write8:
st1 {v8.4s, v9.4s}, [x18], x17
cmp w10, #1
beq WriteEnd
st1 {v10.4s, v11.4s}, [x18], x17
cmp w10, #2
beq WriteEnd
st1 {v12.4s, v13.4s}, [x18], x17
cmp w10, #3
beq WriteEnd
st1 {v14.4s, v15.4s}, [x18], x17
cmp w10, #4
beq WriteEnd
st1 {v16.4s, v17.4s}, [x18], x17
cmp w10, #5
beq WriteEnd
st1 {v18.4s, v19.4s}, [x18], x17
cmp w10, #6
beq WriteEnd
st1 {v20.4s, v21.4s}, [x18], x17
cmp w10, #7
beq WriteEnd
st1 {v22.4s, v23.4s}, [x18], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4s, v25.4s}, [x18], x17
cmp w10, #9
beq WriteEnd
st1 {v26.4s, v27.4s}, [x18], x17
cmp w10, #10
beq WriteEnd
st1 {v28.4s, v29.4s}, [x18], x17
cmp w10, #11
beq WriteEnd
st1 {v30.4s, v31.4s}, [x18], x17
WriteEnd:
subs w10, w10, #12 // lhs row - 12
bgt L2
End2:
subs w7, w7, #8 // rhs col - 8
add x1, x1, x15 // rhs ptr + stride
add x3, x3, #32 // bias ptr + stride
ldrb w13, [sp, #8]
cbz w13, NoDstStep
add x2, x2, #32 // dst ptr + stride
NoDstStep:
bgt L1
End1:
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif
...@@ -28,6 +28,108 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) { ...@@ -28,6 +28,108 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
return; return;
} }
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row12 = row / C12NUM * C12NUM;
size_t col4 = col / C4NUM * C4NUM;
float *src_r = src_ptr;
float *dst_r = dst_ptr;
size_t ri = 0;
for (; ri < row12; ri += C12NUM) {
size_t ci = 0;
for (; ci < col4; ci += C4NUM) {
float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C12NUM;
/* 12x4 row-major to col-major */
#ifdef ENABLE_ARM64
size_t stride = col * sizeof(float);
asm volatile(
"mov x10, %[src_c]\n"
"mov x11, %[dst_c]\n"
"ld1 {v0.4s}, [x10], %[stride]\n"
"ld1 {v1.4s}, [x10], %[stride]\n"
"ld1 {v2.4s}, [x10], %[stride]\n"
"ld1 {v3.4s}, [x10], %[stride]\n"
"ld1 {v4.4s}, [x10], %[stride]\n"
"ld1 {v5.4s}, [x10], %[stride]\n"
"ld1 {v6.4s}, [x10], %[stride]\n"
"ld1 {v7.4s}, [x10], %[stride]\n"
"zip1 v12.4s, v0.4s, v1.4s\n"
"zip2 v13.4s, v0.4s, v1.4s\n"
"zip1 v14.4s, v2.4s, v3.4s\n"
"zip2 v15.4s, v2.4s, v3.4s\n"
"ld1 {v8.4s}, [x10], %[stride]\n"
"ld1 {v9.4s}, [x10], %[stride]\n"
"ld1 {v10.4s}, [x10], %[stride]\n"
"ld1 {v11.4s}, [x10], %[stride]\n"
"zip1 v16.4s, v4.4s, v5.4s\n"
"zip2 v17.4s, v4.4s, v5.4s\n"
"zip1 v18.4s, v6.4s, v7.4s\n"
"zip2 v19.4s, v6.4s, v7.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v23.2d, v12.2d, v14.2d\n"
"trn1 v26.2d, v13.2d, v15.2d\n"
"trn2 v29.2d, v13.2d, v15.2d\n"
"trn1 v21.2d, v16.2d, v18.2d\n"
"trn2 v24.2d, v16.2d, v18.2d\n"
"trn1 v27.2d, v17.2d, v19.2d\n"
"trn2 v30.2d, v17.2d, v19.2d\n"
"zip1 v12.4s, v8.4s, v9.4s\n"
"zip2 v13.4s, v8.4s, v9.4s\n"
"zip1 v14.4s, v10.4s, v11.4s\n"
"zip2 v15.4s, v10.4s, v11.4s\n"
"trn1 v22.2d, v12.2d, v14.2d\n"
"trn2 v25.2d, v12.2d, v14.2d\n"
"trn1 v28.2d, v13.2d, v15.2d\n"
"trn2 v31.2d, v13.2d, v15.2d\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#else
for (int tr = 0; tr < C12NUM; tr++) {
for (int tc = 0; tc < C4NUM; tc++) {
dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
}
}
#endif
}
for (; ci < col; ci++) {
float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C12NUM;
for (size_t i = 0; i < C12NUM; i++) {
dst_c[i] = src_c[i * col];
}
}
src_r += C12NUM * col;
dst_r += C12NUM * col;
}
for (; ri < row; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C12NUM] = src_r[i];
}
src_r += col;
dst_r += 1;
}
return;
}
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) { void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row8 = row / C8NUM * C8NUM; size_t row8 = row / C8NUM * C8NUM;
size_t col4 = col / C4NUM * C4NUM; size_t col4 = col / C4NUM * C4NUM;
...@@ -267,6 +369,31 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac ...@@ -267,6 +369,31 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
return; return;
} }
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, bool write_nhwc) {
if (write_nhwc) {
/* col8-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r12div = r / 12, r12mod = r % 12;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
}
return;
}
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
int stride, bool write_nhwc) { int stride, bool write_nhwc) {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
...@@ -275,3 +402,12 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType ...@@ -275,3 +402,12 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc); MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
#endif #endif
} }
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, int stride, bool write_nhwc) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
#endif
}
...@@ -28,12 +28,17 @@ extern "C" { ...@@ -28,12 +28,17 @@ extern "C" {
#endif #endif
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col, void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
int stride, bool write_nhwc); int stride, bool write_nhwc);
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
int col, int stride, bool write_nhwc);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, bool write_nhwc); int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, bool write_nhwc);
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
} }
......
...@@ -31,6 +31,7 @@ typedef struct MatMulParameter { ...@@ -31,6 +31,7 @@ typedef struct MatMulParameter {
int row_; int row_;
int col_; int col_;
int row_8_; int row_8_;
int row_12_;
int row_16_; int row_16_;
int col_8_; int col_8_;
int deep_; int deep_;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#define C4NUM 4 #define C4NUM 4
#define C8NUM 8 #define C8NUM 8
#define C12NUM 12
#define C16NUM 16 #define C16NUM 16
#define BLOCK 4 #define BLOCK 4
#define TILE_NUM 8 #define TILE_NUM 8
......
...@@ -59,7 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { ...@@ -59,7 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
matmul_param_->col_ = conv_param_->output_channel_; matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_; matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM); matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No; matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_; matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
...@@ -100,12 +100,12 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { ...@@ -100,12 +100,12 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float))); pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
if (pack_input_ == nullptr) { if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)); memset(pack_input_, 0, matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float));
return RET_OK; return RET_OK;
} }
...@@ -118,7 +118,7 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) { ...@@ -118,7 +118,7 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
input_ptr_ = src_input; input_ptr_ = src_input;
} }
RowMajor2Col8Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
return; return;
} }
...@@ -143,7 +143,7 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) { ...@@ -143,7 +143,7 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id; auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_, MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_, output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, true); matmul_param_->row_, cur_oc, matmul_param_->col_, true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册