Commit bd12a37d authored by mindspore-ci-bot, committed by Gitee

!5182 optimization for winograd matmul

Merge pull request !5182 from lixian/master
#ifdef __aarch64__
.text
.align 5
.global MatmulFloatNeon64OptRemain
#ifndef __APPLE__
.type MatmulFloatNeon64OptRemain, %function
#endif
// void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth,
//                                 int row, int col, size_t stride)
// x0: a
// x1: b
// x2: c
// x3: depth
// x4: row
// x5: col
// x6: stride
// only for winograd
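// Tile scheme: v16-v31 accumulate an 8-row x 8-col output tile
// (the LoopH4 path below uses v16-v23 for a 4-row x 8-col tile).
// The lhs is packed in 12-row (C12NUM) blocks by the callers, so each
// depth step loads 12 floats but only the first 8 (or 4) rows are used.
// x9  = 8 * depth * sizeof(float): rhs step to the next 8-column block
// x8  = col * stride * sizeof(float): dst step between output rows
// x11 = 8 * stride * sizeof(float): dst step to the next 8-column block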
MatmulFloatNeon64OptRemain:
mov x18, #32 // sizeof(float) * 8
mul x9, x3, x18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x18, #4
mul x8, x5, x6
mov x11, #8
mul x11, x11, x6
mul x8, x8, x18
mul x11, x11, x18
cmp x4, #4
ble LoopH4
LoopH8:
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
LoopW8:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
LoopD8:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
fmla v24.4s, v3.4s, v1.s[0]
fmla v26.4s, v3.4s, v1.s[1]
fmla v28.4s, v3.4s, v1.s[2]
fmla v30.4s, v3.4s, v1.s[3]
fmla v25.4s, v4.4s, v1.s[0]
fmla v27.4s, v4.4s, v1.s[1]
fmla v29.4s, v4.4s, v1.s[2]
fmla v31.4s, v4.4s, v1.s[3]
subs w13, w13, #1
bgt LoopD8
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
st1 {v24.4s, v25.4s}, [x18], x8
st1 {v26.4s, v27.4s}, [x18], x8
st1 {v28.4s, v29.4s}, [x18], x8
st1 {v30.4s, v31.4s}, [x18], x8
subs x10, x10, #8 // lhs row - 8
bgt LoopW8
subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11
bgt LoopH8
ret
LoopH4:
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
LoopW4:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
LoopD4:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
subs x13, x13, #1
bgt LoopD4
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
subs x10, x10, #4 // lhs row - 4
bgt LoopW4
subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11
bgt LoopH4
ret
#endif
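For reference, a plain-C sketch of what the remainder kernel above computes is given below. It is illustrative only, not part of the patch: the packed layouts are inferred from the loads and stores above, and it writes only `row` rows, whereas the assembly always stores full 4- or 8-row tiles.

#include <stddef.h>

/* Reference sketch: lhs assumed packed as one [depth][12] block, rhs as
 * [col/8][depth][8], dst addressed exactly as x8/x11 are used above. */
static void MatmulFloatNeon64OptRemainRef(const float *a, const float *b, float *c,
                                          int depth, int row, int col, size_t stride) {
  for (int cb = 0; cb < col; cb += 8) {          /* one 8-column block per pass */
    const float *bp = b + (size_t)cb * depth;    /* rhs block step, x1 += x9    */
    float *cp = c + (size_t)cb * stride;         /* dst block step, x2 += x11   */
    for (int r = 0; r < row; ++r) {
      for (int j = 0; j < 8; ++j) {
        float acc = 0.0f;
        for (int d = 0; d < depth; ++d) {
          acc += a[d * 12 + r] * bp[d * 8 + j];  /* 12-packed lhs, 8-packed rhs */
        }
        cp[(size_t)r * col * stride + j] = acc;  /* dst row step, x8            */
      }
    }
  }
}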
@@ -303,7 +303,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
- C12NUM, oc8 * C8NUM, input_unit_square, 2);
+ cal_num, oc8 * C8NUM, input_unit_square, 2);
}
// step 4 : output transform
@@ -489,9 +489,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
- ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2);
+ ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
}
@@ -386,7 +386,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0;
for (int d = 0; d < deep; ++d) {
- size_t ai = src_r_offset + d * row;
+ size_t ai = src_r_offset + d * C12NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
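The change above follows from the packed layouts: RowMajor2Col12Major stores the lhs in 12-row blocks, depth-major within each block, so one step along depth always advances by C12NUM elements, independent of the caller's row count. A small index-arithmetic sketch follows; the block term of the lhs formula is an assumption inferred from the rhs layout in the same loop.

#include <stddef.h>

#define C12NUM 12
#define C8NUM 8

/* lhs element (r, d): 12-row blocks, depth-major inside each block
 * (assumed layout; the diff above only shows the d * C12NUM term). */
static inline size_t PackedLhsIndex(int r, int d, int deep) {
  return (size_t)(r / C12NUM) * C12NUM * deep + (size_t)d * C12NUM + (size_t)(r % C12NUM);
}

/* rhs element (d, c): matches the bi computation above. */
static inline size_t PackedRhsIndex(int d, int c, int deep) {
  return (size_t)(c / C8NUM) * deep * C8NUM + (size_t)d * C8NUM + (size_t)(c % C8NUM);
}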
@@ -403,8 +403,12 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
- MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
-                      (int)(out_type == OutType_TileC8));
+ if (out_type == 2 && row <= 8) {
+   MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride);
+ } else {
+   MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
+                        (int)(out_type == OutType_TileC8));
+ }
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
@@ -39,6 +39,7 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi
int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4);
+ void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
#endif
#ifdef __cplusplus
}