From 61647c357b7c6d2d21e7bee7058ee26cfade5a55 Mon Sep 17 00:00:00 2001 From: Xiaoyang LI Date: Wed, 25 Sep 2019 22:41:39 +0800 Subject: [PATCH] add workspace compute funcs for direct conv, test=develop (#2132) --- .../arm/math/conv3x3s1_direct_fp32.cc | 1156 ++++++--------- .../arm/math/conv3x3s2_direct_fp32.cc | 1243 ++++++----------- lite/backends/arm/math/conv_impl.h | 6 + lite/core/device_info.cc | 8 +- lite/core/device_info.h | 2 +- lite/kernels/arm/conv_direct.cc | 25 + lite/kernels/arm/conv_direct.h | 2 + 7 files changed, 924 insertions(+), 1518 deletions(-) diff --git a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc index 78f9de1c2a..6a1fa37681 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc @@ -26,6 +26,39 @@ namespace lite { namespace arm { namespace math { +const int OUT_C_BLOCK = 4; +const int OUT_H_BLOCK = 2; +const int OUT_W_BLOCK = 4; + +size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param, + ARMContext* ctx) { + auto dim_in = param.x->dims(); + auto dim_out = param.output->dims(); + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / sizeof(float); + const int pad_w = param.paddings[1]; + const int pad_h = param.paddings[0]; + int ow = dim_out[3]; + int oh = dim_out[2]; + int ic = dim_in[1]; + const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); + const int win_round = wout_round + 2; + + int hout_r_block = (llc_size - 2 * win_round * ic) / + (win_round * ic + OUT_C_BLOCK * wout_round * threads); + hout_r_block = hout_r_block > oh ? oh : hout_r_block; + hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; + hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block; + + const int hin_r_block = hout_r_block + 2; + + int in_len = win_round * ic; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; + + return sizeof(float) * (pre_in_size + ctx->threads() * pre_out_size); +} + void conv_3x3s1_direct_fp32(const float* i_data, float* o_data, int bs, @@ -44,19 +77,16 @@ void conv_3x3s1_direct_fp32(const float* i_data, const int pad_h = param.paddings[0]; const int pad_w = param.paddings[1]; - const int hout_c_block = 4; - const int hout_r_kernel = 2; - const int wout_block = 4; - const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block; + const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round + 2; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; int hout_r_block = (l2_size - 2 * win_round * ic) / - (win_round * ic + hout_c_block * wout_round * threads); + (win_round * ic + OUT_C_BLOCK * wout_round * threads); hout_r_block = hout_r_block > oh ? oh : hout_r_block; - hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; - hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; + hout_r_block = hout_r_block < OUT_H_BLOCK ? 
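                   OUT_H_BLOCK : hout_r_block;

This h-block computation repeats the arithmetic of the new `conv3x3s1_direct_workspace_size` above: the L2 budget `llc_size` (counted in floats) must hold one packed input block of `(hout_r_block + 2) * win_round * ic` floats plus, per thread, one packed output block of `OUT_C_BLOCK * wout_round * hout_r_block` floats; solving that inequality for `hout_r_block` gives the quotient, with the `2 * win_round * ic` term reserving the two halo rows. A caller-side sketch of how the new size functions would be consumed — `PrepareDirectConvWorkspace` is a hypothetical name, and an `ExtendWorkspace(size_t)`-style helper on the context is assumed (the real wiring lives in `lite/kernels/arm/conv_direct.cc` and `lite/core/device_info.h`, both touched by this patch):

    // Illustrative only: grow the shared scratch buffer before
    // dispatching a 3x3 direct conv kernel.
    void PrepareDirectConvWorkspace(const operators::ConvParam& param,
                                    ARMContext* ctx, int stride) {
      size_t bytes = (stride == 1)
                         ? conv3x3s1_direct_workspace_size(param, ctx)
                         : conv3x3s2_direct_workspace_size(param, ctx);
      ctx->ExtendWorkspace(bytes);  // assumed byte-count overload
    }

+  hout_r_block = hout_r_block < OUT_H_BLOCK ? 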
OUT_H_BLOCK : hout_r_block; const int hin_r_block = hout_r_block + 2; @@ -67,23 +97,23 @@ void conv_3x3s1_direct_fp32(const float* i_data, int in_len = win_round * ic; int pre_in_size = hin_r_block * in_len; - int pre_out_size = hout_c_block * hout_r_block * wout_round; + int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; float* pre_din = tmp_work_space; int size_in_channel = win * ih; int size_out_channel = ow * oh; - int w_stride = ic * 9; // kernel_w * kernel_h; - int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * + int w_stride = ic * 9; // kernel_w * kernel_h; + int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h * int ws = -pad_w; int we = ws + win_round; int w_loop = wout_round / 4; - int c_remain = oc - (oc / hout_c_block) * hout_c_block; - int c_round_down = (oc / hout_c_block) * hout_c_block; + int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK; + int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK; - int out_row_stride = hout_c_block * wout_round; + int out_row_stride = OUT_C_BLOCK * wout_round; for (int n = 0; n < bs; ++n) { const float* din_batch = i_data + n * ic * size_in_channel; float* dout_batch = o_data + n * oc * size_out_channel; @@ -97,7 +127,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, prepack_input_nxw( din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); #pragma omp parallel for num_threads(threads) - for (int c = 0; c < oc - (hout_c_block - 1); c += hout_c_block) { + for (int c = 0; c < oc - (OUT_C_BLOCK - 1); c += OUT_C_BLOCK) { #ifdef ARM_WITH_OMP float* pre_out = pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; @@ -115,9 +145,9 @@ void conv_3x3s1_direct_fp32(const float* i_data, bias_ptr = bias + c; } fill_packed_biasc4( - pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); + pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel); - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { const float* wc0 = weight_c; const float* inr0 = block_inr0; @@ -148,161 +178,125 @@ void conv_3x3s1_direct_fp32(const float* i_data, const float* r3 = inr3; int cnt = w_loop; + // clang-format off asm volatile( - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, - outr01*/ - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ - "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ - "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ - "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ - "2: \n" /* main loop*/ - /* r0, r1, mul w0, get out r0, r1 */ - "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ - "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ - "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ - "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ - "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ - "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ - "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ - "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ - - /* r0, r1, mul w1, get out r0, r1 */ - "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ - "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ - "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ - "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ - "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ - "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ - "fmla v21.4s , %[w1].4s, 
v2.s[3]\n" /* outr12 = w1 * r1[3]*/ - "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ - - "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ - - /* r0, r1, mul w2, get out r0, r1 */ - "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ - "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ - "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ - "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ - "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ - "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ - "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ - "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ - - /* r1, r2, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ - "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ - "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ - "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ - "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ - "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ - "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ - - "ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/ - - /* r1, r2, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ - "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ - "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ - "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ - "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ - "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ - "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ - "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ - - "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ - - /* r1, r2, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ - "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ - "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ - "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ - "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ - "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ - "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ - - /* r2, r3, mul w6, get out r0, r1 */ - "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ - "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ - "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ - "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ - "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ - "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ - "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ - - "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ - - /* r2, r3, mul w7, get out r0, r1 */ - "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ - "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ - "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ - "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ - "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ - "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ - "fmla v22.4s , %[w7].4s, 
v7.s[0]\n" /* outr13 = w7 * r3[4]*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - - /* r2, r3, mul w8, get out r0, r1 */ - "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ - "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ - "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ - - "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ - "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ - "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ - "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ - "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ - "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ - "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ - "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ - "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ - "bne 2b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); - - wc0 += 9 * hout_c_block; + "ldp q15, q16, [%[ptr_out0]]\n" /* load outr00,outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "2: \n" /* main loop*/ + /* r0, r1, mul w0, get out r0, r1 */ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ + "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ + "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ + "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ + "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ + "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ + "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ + /* r0, r1, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ + "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ + "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ + "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ + "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ + "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ + "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + /* r0, r1, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ + "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ + "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ + "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ + "fmla v20.4s , %[w2].4s, 
v2.s[3]\n" /* outr11 = w2 * r1[3]*/ + "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ + "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ + /* r1, r2, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ + "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ + "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ + "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ + "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ + "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ + "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/ + /* r1, r2, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ + "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ + "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ + "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ + "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ + "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ + "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + /* r1, r2, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ + "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ + "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ + "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ + "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ + "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ + "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ + /* r2, r3, mul w6, get out r0, r1 */ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ + "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ + "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ + "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ + "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ + "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ + "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ + /* r2, r3, mul w7, get out r0, r1 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ + "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ + "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ + "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ + "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ + "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ + "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + /* r2, r3, mul w8, get out r0, r1 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ + "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ + "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* 
save outr02, outr03*/ + "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), + [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), + [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) + : "cc","memory","v0","v1","v2","v3", + "v4","v5","v6","v7","v15","v16", + "v17","v18","v19","v20","v21","v22" + ); + // clang-format on + + wc0 += 9 * OUT_C_BLOCK; inr0 += win_round; inr1 += win_round; inr2 += win_round; @@ -321,273 +315,135 @@ void conv_3x3s1_direct_fp32(const float* i_data, const float* r3 = inr3; int cnt = w_loop; + // clang-format off asm volatile( - "vld1.32 {d16-d19}, [%[ptr_out0]]! @ " - "load outr0, w0, w1, c0~c3\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " - "to q7\n" - - /* load r0, r1 */ - "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " - "4 float\n" - "vld1.32 {d2}, [%[r0]] @ load r0, " - "2 float\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - /* main loop */ - "0: @ main " - "loop\n" - /* mul r0 with w0, w1, w2, get out r0 */ - "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " - "outr1, w0, w1, c0~c3\n" - "vmla.f32 q8, q5, d0[0] @ w0 * " - "inr00\n" - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " - "outr1, w2, w3, c0~c3\n" - "vmla.f32 q9, q5, d0[1] @ w0 * " - "inr01\n" - "vmla.f32 q10, q5, d1[0] @ w0 * " - "inr02\n" - "vmla.f32 q11, q5, d1[1] @ w0 * " - "inr03\n" - "vld1.32 {d3-d4}, [%[r1]]! @ load r1, " - "4 float\n" - "vmla.f32 q8, q6, d0[1] @ w1 * " - "inr01\n" - "vmla.f32 q9, q6, d1[0] @ w1 * " - "inr02\n" - "vmla.f32 q10, q6, d1[1] @ w1 * " - "inr03\n" - "vmla.f32 q11, q6, d2[0] @ w1 * " - "inr04\n" - "vld1.32 {d5}, [%[r1]] @ load r0, " - "2 float\n" - "vmla.f32 q8, q7, d1[0] @ w2 * " - "inr02\n" - "vmla.f32 q9, q7, d1[1] @ w2 * " - "inr03\n" - "vmla.f32 q10, q7, d2[0] @ w2 * " - "inr04\n" - "vmla.f32 q11, q7, d2[1] @ w2 * " - "inr05\n" - - "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " - "- 32, to start address\n" - - /* mul r1 with w0, w1, w2, get out r1 */ - "vmla.f32 q12, q5, d3[0] @ w0 * " - "inr10\n" - "vmla.f32 q13, q5, d3[1] @ w0 * " - "inr11\n" - "vmla.f32 q14, q5, d4[0] @ w0 * " - "inr12\n" - "vmla.f32 q15, q5, d4[1] @ w0 * " - "inr13\n" - "vmla.f32 q12, q6, d3[1] @ w1 * " - "inr11\n" - "vmla.f32 q13, q6, d4[0] @ w1 * " - "inr12\n" - "vmla.f32 q14, q6, d4[1] @ w1 * " - "inr13\n" - "vmla.f32 q15, q6, d5[0] @ w1 * " - "inr14\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " - "w4, to q5, q6\n" - "vmla.f32 q12, q7, d4[0] @ w2 * " - "inr12\n" - "vmla.f32 q13, q7, d4[1] @ w2 * " - "inr13\n" - "vmla.f32 q14, q7, d5[0] @ w2 * " - "inr14\n" - "vmla.f32 q15, q7, d5[1] @ w2 * " - "inr15\n" - "vld1.32 {d14-d15}, [%[wc0]]! 
@ load w5, " - "to q7\n" - - /* mul r1 with w3, w4, w5, get out r0 */ - "vmla.f32 q8, q5, d3[0] @ w3 * " - "inr10\n" - "vmla.f32 q9, q5, d3[1] @ w3 * " - "inr11\n" - "vmla.f32 q10, q5, d4[0] @ w3 * " - "inr12\n" - "vmla.f32 q11, q5, d4[1] @ w3 * " - "inr13\n" - "vld1.32 {d0-d1}, [%[r2]]! @ load r2, " - "4 float\n" - "vmla.f32 q8, q6, d3[1] @ w4 * " - "inr11\n" - "vmla.f32 q9, q6, d4[0] @ w4 * " - "inr12\n" - "vmla.f32 q10, q6, d4[1] @ w4 * " - "inr13\n" - "vmla.f32 q11, q6, d5[0] @ w4 * " - "inr14\n" - "vld1.32 {d2}, [%[r2]] @ load r2, " - "2 float\n" - "vmla.f32 q8, q7, d4[0] @ w5 * " - "inr12\n" - "vmla.f32 q9, q7, d4[1] @ w5 * " - "inr13\n" - "vmla.f32 q10, q7, d5[0] @ w5 * " - "inr14\n" - "vmla.f32 q11, q7, d5[1] @ w5 * " - "inr15\n" - - /* mul r2 with w3, w4, w5, get out r1 */ - "vmla.f32 q12, q5, d0[0] @ w3 * " - "inr20\n" - "vmla.f32 q13, q5, d0[1] @ w3 * " - "inr21\n" - "vmla.f32 q14, q5, d1[0] @ w3 * " - "inr22\n" - "vmla.f32 q15, q5, d1[1] @ w3 * " - "inr23\n" - "vmla.f32 q12, q6, d0[1] @ w4 * " - "inr21\n" - "vmla.f32 q13, q6, d1[0] @ w4 * " - "inr22\n" - "vmla.f32 q14, q6, d1[1] @ w4 * " - "inr23\n" - "vmla.f32 q15, q6, d2[0] @ w4 * " - "inr24\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " - "w7, to q5, q6\n" - "vmla.f32 q12, q7, d1[0] @ w5 * " - "inr22\n" - "vmla.f32 q13, q7, d1[1] @ w5 * " - "inr23\n" - "vmla.f32 q14, q7, d2[0] @ w5 * " - "inr24\n" - "vmla.f32 q15, q7, d2[1] @ w5 * " - "inr25\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " - "to q7\n" - - "sub %[wc0], %[wc0], #144 @ wc0 - " - "144 to start address\n" - - /* mul r2 with w6, w7, w8, get out r0 */ - "vmla.f32 q8, q5, d0[0] @ w6 * " - "inr20\n" - "vmla.f32 q9, q5, d0[1] @ w6 * " - "inr21\n" - "vld1.32 {d3-d4}, [%[r3]]! @ load r3, " - "4 float\n" - "vmla.f32 q10, q5, d1[0] @ w6 * " - "inr22\n" - "vmla.f32 q11, q5, d1[1] @ w6 * " - "inr23\n" - "vmla.f32 q8, q6, d0[1] @ w7 * " - "inr21\n" - "vmla.f32 q9, q6, d1[0] @ w7 * " - "inr22\n" - "vld1.32 {d5}, [%[r3]] @ load r3, " - "2 float\n" - "vmla.f32 q10, q6, d1[1] @ w7 * " - "inr23\n" - "vmla.f32 q11, q6, d2[0] @ w7 * " - "inr24\n" - "vmla.f32 q8, q7, d1[0] @ w8 * " - "inr22\n" - "vmla.f32 q9, q7, d1[1] @ w8 * " - "inr23\n" - "vld1.32 {d0-d1}, [%[r0]]! @ load r0, " - "4 float\n" - "vmla.f32 q10, q7, d2[0] @ w8 * " - "inr24\n" - "vmla.f32 q11, q7, d2[1] @ w8 * " - "inr25\n" - "vld1.32 {d2}, [%[r0]] @ load r0, " - "2 float\n" - - /* mul r3 with w6, w7, w8, get out r1 */ - "vmla.f32 q12, q5, d3[0] @ w6 * " - "inr20\n" - "vmla.f32 q13, q5, d3[1] @ w6 * " - "inr21\n" - "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " - "r00, r01, c0~c3\n" - "vmla.f32 q14, q5, d4[0] @ w6 * " - "inr22\n" - "vmla.f32 q15, q5, d4[1] @ w6 * " - "inr23\n" - "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save " - "r02, r03, c0~c3\n" - "vmla.f32 q12, q6, d3[1] @ w7 * " - "inr21\n" - "vmla.f32 q13, q6, d4[0] @ w7 * " - "inr22\n" - "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " - "outr0, w0, w1, c0~c3\n" - "vmla.f32 q14, q6, d4[1] @ w7 * " - "inr23\n" - "vmla.f32 q15, q6, d5[0] @ w7 * " - "inr24\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vmla.f32 q12, q7, d4[0] @ w8 * " - "inr22\n" - "vmla.f32 q13, q7, d4[1] @ w8 * " - "inr23\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - "vmla.f32 q14, q7, d5[0] @ w8 * " - "inr24\n" - "vmla.f32 q15, q7, d5[1] @ w8 * " - "inr25\n" - - "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " - "r10, r11, c0~c3\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " - "r12, r13, c0~c3\n" - "vld1.32 {d14-d15}, [%[wc0]]! 
@ load w2, " - "to q7\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - "subs %[cnt], #1 @ loop " - "count--\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1), - [wc0] "+r"(wc0) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" + /* load r0, r1 */ + "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n" + "vld1.32 {d2}, [%[r0]] @ load r0\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" + /* main loop */ + "0: @ main loop\n" + /* mul r0 with w0, w1, w2, get out r0 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n" + "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n" + "vmla.f32 q9, q5, d0[1] @ w0 * inr01\n" + "vmla.f32 q10, q5, d1[0] @ w0 * inr02\n" + "vmla.f32 q11, q5, d1[1] @ w0 * inr03\n" + "vld1.32 {d3-d4}, [%[r1]]! @ load r1\n" + "vmla.f32 q8, q6, d0[1] @ w1 * inr01\n" + "vmla.f32 q9, q6, d1[0] @ w1 * inr02\n" + "vmla.f32 q10, q6, d1[1] @ w1 * inr03\n" + "vmla.f32 q11, q6, d2[0] @ w1 * inr04\n" + "vld1.32 {d5}, [%[r1]] @ load r0\n" + "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n" + "vmla.f32 q9, q7, d1[1] @ w2 * inr03\n" + "vmla.f32 q10, q7, d2[0] @ w2 * inr04\n" + "vmla.f32 q11, q7, d2[1] @ w2 * inr05\n" + "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n" + /* mul r1 with w0, w1, w2, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w0 * inr10\n" + "vmla.f32 q13, q5, d3[1] @ w0 * inr11\n" + "vmla.f32 q14, q5, d4[0] @ w0 * inr12\n" + "vmla.f32 q15, q5, d4[1] @ w0 * inr13\n" + "vmla.f32 q12, q6, d3[1] @ w1 * inr11\n" + "vmla.f32 q13, q6, d4[0] @ w1 * inr12\n" + "vmla.f32 q14, q6, d4[1] @ w1 * inr13\n" + "vmla.f32 q15, q6, d5[0] @ w1 * inr14\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4\n" + "vmla.f32 q12, q7, d4[0] @ w2 * inr12\n" + "vmla.f32 q13, q7, d4[1] @ w2 * inr13\n" + "vmla.f32 q14, q7, d5[0] @ w2 * inr14\n" + "vmla.f32 q15, q7, d5[1] @ w2 * inr15\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w5\n" + /* mul r1 with w3, w4, w5, get out r0 */ + "vmla.f32 q8, q5, d3[0] @ w3 * inr10\n" + "vmla.f32 q9, q5, d3[1] @ w3 * inr11\n" + "vmla.f32 q10, q5, d4[0] @ w3 * inr12\n" + "vmla.f32 q11, q5, d4[1] @ w3 * inr13\n" + "vld1.32 {d0-d1}, [%[r2]]! @ load r2\n" + "vmla.f32 q8, q6, d3[1] @ w4 * inr11\n" + "vmla.f32 q9, q6, d4[0] @ w4 * inr12\n" + "vmla.f32 q10, q6, d4[1] @ w4 * inr13\n" + "vmla.f32 q11, q6, d5[0] @ w4 * inr14\n" + "vld1.32 {d2}, [%[r2]] @ load r2\n" + "vmla.f32 q8, q7, d4[0] @ w5 * inr12\n" + "vmla.f32 q9, q7, d4[1] @ w5 * inr13\n" + "vmla.f32 q10, q7, d5[0] @ w5 * inr14\n" + "vmla.f32 q11, q7, d5[1] @ w5 * inr15\n" + /* mul r2 with w3, w4, w5, get out r1 */ + "vmla.f32 q12, q5, d0[0] @ w3 * inr20\n" + "vmla.f32 q13, q5, d0[1] @ w3 * inr21\n" + "vmla.f32 q14, q5, d1[0] @ w3 * inr22\n" + "vmla.f32 q15, q5, d1[1] @ w3 * inr23\n" + "vmla.f32 q12, q6, d0[1] @ w4 * inr21\n" + "vmla.f32 q13, q6, d1[0] @ w4 * inr22\n" + "vmla.f32 q14, q6, d1[1] @ w4 * inr23\n" + "vmla.f32 q15, q6, d2[0] @ w4 * inr24\n" + "vld1.32 {d10-d13}, [%[wc0]]! 
@ load w6, w7\n" + "vmla.f32 q12, q7, d1[0] @ w5 * inr22\n" + "vmla.f32 q13, q7, d1[1] @ w5 * inr23\n" + "vmla.f32 q14, q7, d2[0] @ w5 * inr24\n" + "vmla.f32 q15, q7, d2[1] @ w5 * inr25\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n" + "sub %[wc0], %[wc0], #144 @ wc0 - 144\n" + /* mul r2 with w6, w7, w8, get out r0 */ + "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n" + "vmla.f32 q9, q5, d0[1] @ w6 * inr21\n" + "vld1.32 {d3-d4}, [%[r3]]! @ load r3\n" + "vmla.f32 q10, q5, d1[0] @ w6 * inr22\n" + "vmla.f32 q11, q5, d1[1] @ w6 * inr23\n" + "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n" + "vmla.f32 q9, q6, d1[0] @ w7 * inr22\n" + "vld1.32 {d5}, [%[r3]] @ load r3\n" + "vmla.f32 q10, q6, d1[1] @ w7 * inr23\n" + "vmla.f32 q11, q6, d2[0] @ w7 * inr24\n" + "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n" + "vmla.f32 q9, q7, d1[1] @ w8 * inr23\n" + "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n" + "vmla.f32 q10, q7, d2[0] @ w8 * inr24\n" + "vmla.f32 q11, q7, d2[1] @ w8 * inr25\n" + "vld1.32 {d2}, [%[r0]] @ load r0\n" + /* mul r3 with w6, w7, w8, get out r1 */ + "vmla.f32 q12, q5, d3[0] @ w6 * inr20\n" + "vmla.f32 q13, q5, d3[1] @ w6 * inr21\n" + "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n" + "vmla.f32 q14, q5, d4[0] @ w6 * inr22\n" + "vmla.f32 q15, q5, d4[1] @ w6 * inr23\n" + "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n" + "vmla.f32 q12, q6, d3[1] @ w7 * inr21\n" + "vmla.f32 q13, q6, d4[0] @ w7 * inr22\n" + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" + "vmla.f32 q14, q6, d4[1] @ w7 * inr23\n" + "vmla.f32 q15, q6, d5[0] @ w7 * inr24\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vmla.f32 q12, q7, d4[0] @ w8 * inr22\n" + "vmla.f32 q13, q7, d4[1] @ w8 * inr23\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + "vmla.f32 q14, q7, d5[0] @ w8 * inr24\n" + "vmla.f32 q15, q7, d5[1] @ w8 * inr25\n" + "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n" + "vld1.32 {d14-d15}, [%[wc0]]! 
@ load w2\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" + "subs %[cnt], #1 @ loop count--\n" + "bne 0b @ jump to main loop\n" + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(wc0) + : + : "cc","memory","q0","q1","q2","q3", + "q4","q5","q6","q7","q8","q9", + "q10","q11","q12","q13","q14","q15"); + // clang-format on inr0 += win_round; inr1 += win_round; inr2 += win_round; @@ -602,7 +458,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, write_to_output_c4_fp32(pre_out, dout_batch, c, - c + hout_c_block, + c + OUT_C_BLOCK, h, h + h_kernel, 0, @@ -641,7 +497,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, } fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { const float* wc0 = weight_remain_ptr; const float* inr0 = block_inr0; @@ -672,109 +528,66 @@ void conv_3x3s1_direct_fp32(const float* i_data, const float* r3 = inr3; int cnt = w_loop; + // clang-format off asm volatile( - "ldr q21, [%[ptr_out0]] \n" /* load outr0, - w0~w3*/ - "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ - "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ - "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ - "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ - "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ - "2: \n" /* main loop*/ - - "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ - "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ - - "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ - "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ - "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ - "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ - - "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ - - "fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/ - "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ - - "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ - "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ - - "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ - "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ - - "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ - "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ - "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ - "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ - - "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ - "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ - - "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ - "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ - - "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ - - "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ - "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ - - "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ - - "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ - "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ - - "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ - - "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ - "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ - - "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ - "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - - "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ - "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ - - "bne 2b \n" /* jump to main loop*/ - - : [cnt] 
"+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v21", - "v22"); - - wc0 += 9 * hout_c_block; + "ldr q21, [%[ptr_out0]]\n" /* load outr0, w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + "2: \n" /* main loop*/ + "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ + "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ + "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ + "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ + "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ + "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ + "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ + "fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/ + "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ + "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ + "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ + "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ + "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ + "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ + "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ + "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ + "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ + "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ + "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ + "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ + "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ + "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ + "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ + "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ + "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ + "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ + "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ + "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ + "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ + "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ + "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ + "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ + "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ + "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ + "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), + [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), + [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) + : "cc","memory","v0", + "v1","v2","v3","v4","v5","v6", + "v7","v8","v9","v10","v11","v12", + "v13","v14","v15","v21","v22" + ); + // clang-format on + wc0 += 9 * OUT_C_BLOCK; inr0 += win_round; inr1 += win_round; inr2 += win_round; @@ -806,181 +619,96 @@ void conv_3x3s1_direct_fp32(const float* i_data, const float* r3 = inr3; int cnt = w_loop / 2; if (cnt > 0) { + // clang-format off asm volatile( - "vld1.32 {d24-d27}, 
[%[ptr_out0]] @ " - "load or00, or01\n" - "vld1.32 {d6-d9}, [%[r0]]! @ load r0, 8 " - "float\n" - "vld1.32 {d10}, [%[r0]] @ load r0, 2 " - "float\n" - /* main loop */ - "0: @ main loop\n" - /* r0 * w0, w1, w2, get out r0*/ - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " - "or11\n" - "vext.32 q8, q3, q4, #1 @ r0, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r0, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, " - "0, 1, 2, 3\n" - "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, " - "4, 5, 6, 7\n" - "vext.32 q10, q3, q4, #2 @ r0, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r0, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " - "1, 2, 3, 4\n" - "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, " - "5, 6, 7, 8\n" - "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 " - "float\n" - "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, " - "2, 3, 4, 5\n" - "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r1]] @ load r1, 2 " - "float\n" - - /* r1 * w3, w4, w5, get out r0*/ - /* r1 * w0, w1, w2, get out r1*/ - "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, " - "0, 1, 2, 3\n" - "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, " - "4, 5, 6, 7\n" - "vext.32 q8, q3, q4, #1 @ r1, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r1, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, " - "4, 5, 6, 7\n" - "vext.32 q10, q3, q4, #2 @ r1, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r1, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " - "1, 2, 3, 4\n" - "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, " - "5, 6, 7, 8\n" - "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, " - "1, 2, 3, 4\n" - "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, " - "5, 6, 7, 8\n" - "vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 " - "float\n" - "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, " - "2, 3, 4, 5\n" - "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, " - "6, 7, 8, 9\n" - "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, " - "2, 3, 4, 5\n" - "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r2]] @ load r2, 2 " - "float\n" - - /* r2 * w6, w7, w8, get out r0*/ - /* r2 * w3, w4, w5, get out r1*/ - "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " - "0, 1, 2, 3\n" - "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, " - "4, 5, 6, 7\n" - "vext.32 q8, q3, q4, #1 @ r2, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r2, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, " - "4, 5, 6, 7\n" - "vext.32 q10, q3, q4, #2 @ r2, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r2, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, " - "1, 2, 3, 4\n" - "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, " - "5, 6, 7, 8\n" - "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, " - "1, 2, 3, 4\n" - "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, " - "5, 6, 7, 8\n" - "vld1.32 {d6-d9}, [%[r3]]! 
@ load r3, 8 " - "float\n" - "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, " - "2, 3, 4, 5\n" - "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, " - "6, 7, 8, 9\n" - "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, " - "2, 3, 4, 5\n" - "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r3]] @ load r3, 2 " - "float\n" - - /* r3 * w6, w7, w8, get out r1*/ - "vext.32 q8, q3, q4, #1 @ r3, shift " - "left 1, get 1, 2, 3, 4\n" - "vext.32 q9, q4, q5, #1 @ r3, shift " - "left 1, get 5, 6, 7, 8\n" - "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, " - "4, 5, 6, 7\n" - "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, " - "or01\n" - "vext.32 q10, q3, q4, #2 @ r3, shift " - "left 2, get 2, 3, 4, 5\n" - "vext.32 q11, q4, q5, #2 @ r3, shift " - "left 2, get 6, 7, 8, 9\n" - "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, " - "0, 1, 2, 3\n" - "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, " - "4, 5, 6, 7\n" - "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " - "or01\n" - "vld1.32 {d6-d9}, [%[r0]]! @ load r3, 8 " - "float\n" - "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, " - "2, 3, 4, 5\n" - "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, " - "6, 7, 8, 9\n" - "vld1.32 {d10}, [%[r0]] @ load r0, 2 " - "float\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, " - "or11\n" - - "subs %[cnt], #1 @loop count " - "-1\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n" + "vld1.32 {d6-d9}, [%[r0]]! @ load r0\n" + "vld1.32 {d10}, [%[r0]] @ load r0\n" + /* main loop */ + "0: @ main loop\n" + /* r0 * w0, w1, w2, get out r0*/ + "vld1.32 {d28-d31}, [%[ptr_out1]]@ load or10 or11\n" + "vext.32 q8, q3, q4, #1 @ r0, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r0, shift left 1\n" + "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0\n" + "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0\n" + "vext.32 q10, q3, q4, #2 @ r0, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r0, shift left 2\n" + "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n" + "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0\n" + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8\n" + "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0\n" + "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0\n" + "vld1.32 {d10}, [%[r1]] @ load r1\n" + /* r1 * w3, w4, w5, get out r0*/ + /* r1 * w0, w1, w2, get out r1*/ + "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1\n" + "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1\n" + "vext.32 q8, q3, q4, #1 @ r1, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r1, shift left 1\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1\n" + "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1\n" + "vext.32 q10, q3, q4, #2 @ r1, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r1, shift left 2\n" + "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n" + "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1\n" + "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1\n" + "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1\n" + "vld1.32 {d6-d9}, [%[r2]]! 
@ load r2\n" + "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1\n" + "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1\n" + "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1\n" + "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1\n" + "vld1.32 {d10}, [%[r2]] @ load r2\n" + /* r2 * w6, w7, w8, get out r0*/ + /* r2 * w3, w4, w5, get out r1*/ + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n" + "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2\n" + "vext.32 q8, q3, q4, #1 @ r2, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r2, shift left 1\n" + "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2\n" + "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2\n" + "vext.32 q10, q3, q4, #2 @ r2, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r2, shift left 2\n" + "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2\n" + "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2\n" + "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2\n" + "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2\n" + "vld1.32 {d6-d9}, [%[r3]]! @ load r3\n" + "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2\n" + "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2\n" + "vld1.32 {d10}, [%[r3]] @ load r3\n" + /* r3 * w6, w7, w8, get out r1*/ + "vext.32 q8, q3, q4, #1 @ r3, shift left 1\n" + "vext.32 q9, q4, q5, #1 @ r3, shift left 1\n" + "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3\n" + "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3\n" + "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, or01\n" + "vext.32 q10, q3, q4, #2 @ r3, shift left 2\n" + "vext.32 q11, q4, q5, #2 @ r3, shift left 2\n" + "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3\n" + "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3\n" + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00,or01\n" + "vld1.32 {d6-d9}, [%[r0]]! @ load r3\n" + "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3\n" + "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3\n" + "vld1.32 {d10}, [%[r0]] @ load r0\n" + "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, or11\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ jump to main loop\n" + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) + : "cc","memory","q3","q4", + "q5","q6","q7","q8","q9","q10", + "q11","q12","q13","q14","q15" + ); + // clang-format on r0 -= 8; } //! deal with remain ow diff --git a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc index dbe0706f58..8260718a50 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc @@ -24,6 +24,39 @@ namespace lite { namespace arm { namespace math { +const int OUT_C_BLOCK = 4; +const int OUT_H_BLOCK = 2; +const int OUT_W_BLOCK = 4; + +size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param, + ARMContext* ctx) { + auto dim_in = param.x->dims(); + auto dim_out = param.output->dims(); + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / sizeof(float); + const int pad_w = param.paddings[1]; + const int pad_h = param.paddings[0]; + int ow = dim_out[3]; + int oh = dim_out[2]; + int ic = dim_in[1]; + const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); + const int win_round = wout_round * 2 /*stride_w*/ + 1; + const int hin_r_block = OUT_H_BLOCK * 2 /*stride_h*/ + 1; + + int hout_r_block = + (llc_size - 2 * wout_round * ic - ic) / + ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads); + hout_r_block = hout_r_block > oh ? 
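                     oh : hout_r_block;

The quotient above is the kernel's block-size equation (restated in the `//! get h block` comment further down) solved for `hout_r_block`: with `win_round = 2 * wout_round + 1` and `hin_r_block = 2 * hout_r_block + 1`, requiring `win_round * ic * hin_r_block + wout_round * OUT_C_BLOCK * hout_r_block * threads <= llc_size` rearranges to `hout_r_block * ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads) <= llc_size - 2 * wout_round * ic - ic`. A standalone self-check with illustrative numbers (the shapes and cache size are made up, not taken from this patch):

    #include <cassert>
    int main() {
      const int llc_size = 512 * 1024 / 4;  // a 512 KB L2, counted in floats
      const int wout_round = 112, ic = 32, threads = 4, OUT_C_BLOCK = 4;
      int hout_r_block =
          (llc_size - 2 * wout_round * ic - ic) /
          ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads);
      // Re-substituting must keep packed input plus per-thread packed
      // output within the L2 budget.
      const int win_round = 2 * wout_round + 1;
      const int hin_r_block = 2 * hout_r_block + 1;
      assert(win_round * ic * hin_r_block +
                 wout_round * OUT_C_BLOCK * hout_r_block * threads <=
             llc_size);
      return 0;
    }

+  hout_r_block = hout_r_block > oh ? 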
oh : hout_r_block; + hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; + hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block; + + int in_len = win_round * ic; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; + + return sizeof(float) * (pre_in_size + ctx->threads() * pre_out_size); +} + void conv_3x3s2_direct_fp32(const float* i_data, float* o_data, int bs, @@ -44,53 +77,50 @@ void conv_3x3s2_direct_fp32(const float* i_data, int l2_size = ctx->llc_size() / sizeof(float); const int pad_w = param.paddings[1]; const int pad_h = param.paddings[0]; - const int hout_c_block = 4; - const int hout_r_kernel = 2; - const int wout_block = 4; - const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block; + const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round * 2 /*stride_w*/ + 1; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; //! get h block - //! win_round * ic * hin_r_block + wout_round * hout_c_block * hout_r_block + //! win_round * ic * hin_r_block + wout_round * OUT_C_BLOCK * hout_r_block //! * threads = l2_size //! win_round = 2 * wout_round + 1 //! hin_r_block = 2 * hout_r_block + 1 int hout_r_block = (l2_size - 2 * wout_round * ic - ic) / - ((4 * wout_round + 2) * ic + wout_round * hout_c_block * threads); + ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads); hout_r_block = hout_r_block > oh ? oh : hout_r_block; - hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; - hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; + hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block; const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1; + int in_len = win_round * ic; + int pre_in_size = hin_r_block * in_len; + int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; + float* tmp_work_space = ctx->workspace_data(); float ptr_zero[win_round]; // NOLINT memset(ptr_zero, 0, sizeof(float) * win_round); float ptr_write[wout_round]; // NOLINT - int in_len = win_round * ic; - int pre_in_size = hin_r_block * in_len; - int pre_out_size = hout_c_block * hout_r_block * wout_round; - //! 
l2_cache start float* pre_din = tmp_work_space; int size_in_channel = win * ih; int size_out_channel = ow * oh; - int w_stride = ic * 9; /*kernel_w * kernel_h*/ - int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * + int w_stride = ic * 9; /*kernel_w * kernel_h*/ + int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h * int ws = -pad_w; int we = ws + win_round; int w_loop = wout_round / 4; - int c_remain = oc - (oc / hout_c_block) * hout_c_block; - int c_round_down = (oc / hout_c_block) * hout_c_block; + int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK; + int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK; - int out_row_stride = hout_c_block * wout_round; + int out_row_stride = OUT_C_BLOCK * wout_round; for (int n = 0; n < bs; ++n) { const float* din_batch = i_data + n * ic * size_in_channel; @@ -114,7 +144,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, const float* cblock_inr4 = cblock_inr3 + in_len; #pragma omp parallel for num_threads(threads) - for (int c = 0; c < c_round_down; c += hout_c_block) { + for (int c = 0; c < c_round_down; c += OUT_C_BLOCK) { #ifdef ARM_WITH_OMP float* pre_out = pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; @@ -133,9 +163,9 @@ void conv_3x3s2_direct_fp32(const float* i_data, bias_ptr = bias + c; } fill_packed_biasc4( - pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); + pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel); - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { const float* wc0 = weight_c; const float* inr0 = block_inr0; @@ -168,205 +198,133 @@ void conv_3x3s2_direct_fp32(const float* i_data, const float* r4 = inr4; int cnt = w_loop; + // clang-format off asm volatile( - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, - outr01*/ - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - - "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "2: \n" /* main loop*/ - /* r0, r2, mul w0, get out r0, r1 */ - "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ - "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ - "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ - "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ - "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ - "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ - "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ - "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ - "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ - "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ - - "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ - - /* r2 mul w6, get out r0*/ - "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ - "fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ - "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ - - "ldr d11, [%[r1]] \n" /* load input r1, 9th - element*/ - - /* r0, r2, mul w1, get out r0, r1 */ - "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ - "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ - "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ - "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ - "fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ - "fmla 
v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/ - "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ - "fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ - - "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ - - /* r2 mul w7, get out r0 */ - "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ - "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ - "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ - - "ldr d13, [%[r3]] \n" /* load input r3, 9th - element*/ - - /* r0, r2, mul w2, get out r0, r1 */ - "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ - "fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ - "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ - "fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 * - r0[8]*/ - "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ - "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ - "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ - "fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 * - r2[8]*/ - - "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ - - /* r2, mul w8, get out r0 */ - "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ - "fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ - "fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 * - r2[8]*/ - - "ldr d14, [%[r4]] \n" /* load input r4, 9th - element*/ - - /* r1, r3, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ - "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ - "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ - "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ - "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ - "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ - "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ - - "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ - - /* r1, r3, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ - "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ - "fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/ - "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ - "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ - "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ - "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ - "fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ - - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - - /* r1, r3, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ - "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ - "fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 * - r1[8]*/ - - "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ - "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ - - "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ - "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ - "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ - "fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 * - r3[8]*/ - - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ - - /* r4, mul w6, get out 
r1 */ - "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/ - "fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/ - "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/ - "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/ - - "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ - - /* r4, mul w7, get out r1 */ - "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/ - "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/ - "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/ - "fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/ - - "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ - - /* r4, mul w8, get out r1 */ - "fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/ - "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/ - "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/ - "fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 * - r4[8]*/ - - "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ - - "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ - "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ - - "bne 2b \n" /* jump to main loop*/ - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [w7] "w"(w7), - [w8] "w"(w8) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); - - wc0 += 9 * hout_c_block; + "ldp q15, q16, [%[ptr_out0]]\n" /* load outr00, outr01*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + "ldp q0, q1, [%[r0]], #32\n" /* load input r0*/ + "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/ + "ldp q4, q5, [%[r2]], #32\n" /* load input r2*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "2:\n" /* main loop*/ + /* r0, r2, mul w0, get out r0, r1 */ + "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ + "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ + "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ + "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ + "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ + "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ + "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ + "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ + "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ + "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ + "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ + /* r2 mul w6, get out r0*/ + "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ + "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ + "fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ + "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ + "ldr d11, [%[r1]]\n" /* load input r1, 9th element*/ + /* r0, r2, mul w1, get out r0, r1 */ + "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ + "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ + "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ + "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ + "fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ + "fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = 
w1 * r2[3]*/ + "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ + "fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ + "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ + /* r2 mul w7, get out r0 */ + "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ + "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ + "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ + "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ + "ldr d13, [%[r3]]\n" /* load input r3, 9th element*/ + /* r0, r2, mul w2, get out r0, r1 */ + "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ + "fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ + "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ + "fmla v18.4s , %[w2].4s, v10.s[0]\n"/* outr03 = w2 * r0[8]*/ + "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ + "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ + "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ + "fmla v22.4s , %[w2].4s, v12.s[0]\n"/* outr13 = w2 * r2[8]*/ + "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ + /* r2, mul w8, get out r0 */ + "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ + "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ + "fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ + "fmla v18.4s , %[w8].4s, v12.s[0]\n"/* outr03 = w8 * r2[8]*/ + "ldr d14, [%[r4]]\n" /* load input r4, 9th element*/ + /* r1, r3, mul w3, get out r0, r1 */ + "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ + "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ + "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ + "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ + "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ + "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ + "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ + "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ + "ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ + /* r1, r3, mul w4, get out r0, r1 */ + "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ + "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ + "fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/ + "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ + "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ + "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ + "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ + "fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ + "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/ + /* r1, r3, mul w5, get out r0, r1 */ + "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ + "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ + "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ + "fmla v18.4s , %[w5].4s, v11.s[0]\n"/* outr03 = w5 * r1[8]*/ + "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ + "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ + "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ + "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ + "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ + "fmla v22.4s , %[w5].4s, v13.s[0]\n"/* outr13 = w5 * r3[8]*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ + /* r4, mul w6, get out r1 */ + "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/ + "fmla v20.4s , %[w6].4s, 
v8.s[2]\n" /* outr11 = w6 * r4[2]*/ + "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/ + "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/ + "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ + /* r4, mul w7, get out r1 */ + "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/ + "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/ + "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/ + "fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/ + "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ + /* r4, mul w8, get out r1 */ + "fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/ + "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/ + "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/ + "fmla v22.4s , %[w8].4s, v14.s[0]\n"/* outr13 = w8 * r4[8]*/ + "subs %w[cnt], %w[cnt], #1\n" /*loop count -1*/ + "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ + "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ + "bne 2b \n" /* jump to main loop*/ + : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1) + : [w0] "w"(w0), + [w1] "w"(w1), [w2] "w"(w2), + [w3] "w"(w3), [w4] "w"(w4), + [w5] "w"(w5), [w6] "w"(w6), + [w7] "w"(w7), [w8] "w"(w8) + : "cc","memory","v0","v1","v2","v3","v4", + "v5","v6","v7","v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19","v20","v21","v22"); + // clang-format on + wc0 += 9 * OUT_C_BLOCK; inr0 += win_round; inr1 += win_round; inr2 += win_round; @@ -387,285 +345,142 @@ void conv_3x3s2_direct_fp32(const float* i_data, const float* r4 = inr4; int cnt = w_loop; + // clang-format off asm volatile( - "vld1.32 {d16-d19}, [%[ptr_out0]]! @ " - "load outr0, w0, w1, c0~c3\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " - "to q7\n" - - /* load r0, r2 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " - "8 float\n" - "vld1.32 {d8}, [%[r0]] @ load r0, " - "9th float\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - /* main loop */ - "0: @ main " - "loop\n" - /* mul r0, with w0, w1, w2 */ - "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " - "outr1, w0, w1, c0~c3\n" - "vmla.f32 q8, q5, d0[0] @ w0 * " - "inr00\n" - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load " - "outr1, w2, w3, c0~c3\n" - "vmla.f32 q9, q5, d1[0] @ w0 * " - "inr02\n" - "vmla.f32 q10, q5, d2[0] @ w0 * " - "inr04\n" - "vmla.f32 q11, q5, d3[0] @ w0 * " - "inr06\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, " - "8 float\n" - "vmla.f32 q8, q6, d0[1] @ w1 * " - "inr01\n" - "vmla.f32 q9, q6, d1[1] @ w1 * " - "inr03\n" - "vmla.f32 q10, q6, d2[1] @ w1 * " - "inr05\n" - "vmla.f32 q11, q6, d3[1] @ w1 * " - "inr07\n" - "vld1.32 {d9}, [%[r2]] @ load r2, " - "9th float\n" - "vmla.f32 q8, q7, d1[0] @ w2 * " - "inr02\n" - "vmla.f32 q9, q7, d2[0] @ w2 * " - "inr04\n" - "vmla.f32 q10, q7, d3[0] @ w2 * " - "inr06\n" - "vmla.f32 q11, q7, d8[0] @ w2 * " - "inr08\n" - - "sub %[r2], %[r2], #32 @ r2 - 32, " - "load r2 twice\n" - - /* mul r2, with w0, w1, w2 */ - "vld1.32 {d0-d3}, [%[r1]]! 
@ load r1, " - "8 float\n" - "vmla.f32 q12, q5, d4[0] @ w0 * " - "inr20\n" - "vmla.f32 q13, q5, d5[0] @ w0 * " - "inr22\n" - "vmla.f32 q14, q5, d6[0] @ w0 * " - "inr24\n" - "vmla.f32 q15, q5, d7[0] @ w0 * " - "inr26\n" - "vld1.32 {d8}, [%[r1]] @ load r1, " - "9th float\n" - "vmla.f32 q12, q6, d4[1] @ w1 * " - "inr21\n" - "vmla.f32 q13, q6, d5[1] @ w1 * " - "inr23\n" - "vmla.f32 q14, q6, d6[1] @ w1 * " - "inr25\n" - "vmla.f32 q15, q6, d7[1] @ w1 * " - "inr27\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " - "w4, to q5, q6\n" - "vmla.f32 q12, q7, d5[0] @ w2 * " - "inr22\n" - "vmla.f32 q13, q7, d6[0] @ w2 * " - "inr24\n" - "vmla.f32 q14, q7, d7[0] @ w2 * " - "inr26\n" - "vmla.f32 q15, q7, d9[0] @ w2 * " - "inr28\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " - "to q7\n" - - /* mul r1, with w3, w4, w5 */ - "vmla.f32 q8, q5, d0[0] @ w3 * " - "inr10\n" - "vmla.f32 q9, q5, d1[0] @ w3 * " - "inr12\n" - "vmla.f32 q10, q5, d2[0] @ w3 * " - "inr14\n" - "vmla.f32 q11, q5, d3[0] @ w3 * " - "inr16\n" - "vld1.32 {d4-d7}, [%[r3]]! @ load r3, " - "8 float\n" - "vmla.f32 q8, q6, d0[1] @ w4 * " - "inr11\n" - "vmla.f32 q9, q6, d1[1] @ w4 * " - "inr13\n" - "vmla.f32 q10, q6, d2[1] @ w4 * " - "inr15\n" - "vmla.f32 q11, q6, d3[1] @ w4 * " - "inr17\n" - "vld1.32 {d9}, [%[r3]] @ load r3, " - "9th float\n" - "vmla.f32 q8, q7, d1[0] @ w5 * " - "inr12\n" - "vmla.f32 q9, q7, d2[0] @ w5 * " - "inr14\n" - "vmla.f32 q10, q7, d3[0] @ w5 * " - "inr16\n" - "vmla.f32 q11, q7, d8[0] @ w5 * " - "inr18\n" - - "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " - "- 32, to start address\n" - - /* mul r3, with w3, w4, w5 */ - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, " - "8 float\n" - "vmla.f32 q12, q5, d4[0] @ w3 * " - "inr30\n" - "vmla.f32 q13, q5, d5[0] @ w3 * " - "inr32\n" - "vmla.f32 q14, q5, d6[0] @ w3 * " - "inr34\n" - "vmla.f32 q15, q5, d7[0] @ w3 * " - "inr36\n" - "vld1.32 {d8}, [%[r2]] @ load r2, " - "9th float\n" - "vmla.f32 q12, q6, d4[1] @ w4 * " - "inr31\n" - "vmla.f32 q13, q6, d5[1] @ w4 * " - "inr33\n" - "vmla.f32 q14, q6, d6[1] @ w4 * " - "inr35\n" - "vmla.f32 q15, q6, d7[1] @ w4 * " - "inr37\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, " - "w7, to q5, q6\n" - "vmla.f32 q12, q7, d5[0] @ w5 * " - "inr32\n" - "vmla.f32 q13, q7, d6[0] @ w5 * " - "inr34\n" - "vmla.f32 q14, q7, d7[0] @ w5 * " - "inr36\n" - "vmla.f32 q15, q7, d9[0] @ w5 * " - "inr38\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w8, " - "to q7\n" - - /* mul r2, with w6, w7, w8 */ - "vmla.f32 q8, q5, d0[0] @ w6 * " - "inr20\n" - "vmla.f32 q9, q5, d1[0] @ w6 * " - "inr22\n" - "vmla.f32 q10, q5, d2[0] @ w6 * " - "inr24\n" - "vmla.f32 q11, q5, d3[0] @ w6 * " - "inr26\n" - "vld1.32 {d4-d7}, [%[r4]]! @ load r4, " - "8 float\n" - "vmla.f32 q8, q6, d0[1] @ w7 * " - "inr21\n" - "vmla.f32 q9, q6, d1[1] @ w7 * " - "inr23\n" - "vmla.f32 q10, q6, d2[1] @ w7 * " - "inr25\n" - "vmla.f32 q11, q6, d3[1] @ w7 * " - "inr27\n" - "vld1.32 {d9}, [%[r4]] @ load r4, " - "9th float\n" - "vmla.f32 q8, q7, d1[0] @ w8 * " - "inr22\n" - "vmla.f32 q9, q7, d2[0] @ w8 * " - "inr24\n" - "vmla.f32 q10, q7, d3[0] @ w8 * " - "inr26\n" - "vmla.f32 q11, q7, d8[0] @ w8 * " - "inr28\n" - - "sub %[wc0], %[wc0], #144 @ wc0 - " - "144 to start address\n" - - /* mul r4, with w6, w7, w8 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, " - "8 float\n" - "vmla.f32 q12, q5, d4[0] @ w3 * " - "inr40\n" - "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save " - "r00, r01, c0~c3\n" - "vmla.f32 q13, q5, d5[0] @ w3 * " - "inr42\n" - "vst1.32 {d20-d23}, [%[ptr_out0]]! 
@ save " - "r02, r03, c0~c3\n" - "vmla.f32 q14, q5, d6[0] @ w3 * " - "inr44\n" - "vmla.f32 q15, q5, d7[0] @ w3 * " - "inr46\n" - "vld1.32 {d8}, [%[r0]] @ load " - "r0, 9th float\n" - "vmla.f32 q12, q6, d4[1] @ w4 * " - "inr41\n" - "vmla.f32 q13, q6, d5[1] @ w4 * " - "inr43\n" - "vmla.f32 q14, q6, d6[1] @ w4 * " - "inr45\n" - "vmla.f32 q15, q6, d7[1] @ w4 * " - "inr47\n" - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " - "w1, to q5, q6\n" - "vmla.f32 q12, q7, d5[0] @ w5 * " - "inr42\n" - "vmla.f32 q13, q7, d6[0] @ w5 * " - "inr44\n" - "vmla.f32 q14, q7, d7[0] @ w5 * " - "inr46\n" - "vmla.f32 q15, q7, d9[0] @ w5 * " - "inr48\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " - "to q7\n" - - "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save " - "r10, r11, c0~c3\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save " - "r12, r13, c0~c3\n" - - "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load " - "outr0, w0, w1, c0~c3\n" - "vld1.32 {d20-d23}, [%[ptr_out0]] @ load " - "outr0, w2, w3, c0~c3\n" - - "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " - "- 32, to start address\n" - - "subs %[cnt], #1 @ loop " - "count--\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1), - [wc0] "+r"(wc0) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + /* load weights */ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" + /* load r0, r2 */ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0\n" + "vld1.32 {d8}, [%[r0]] @ load r0\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 -32\n" + /* main loop */ + "0: @ main loop\n" + /* mul r0, with w0, w1, w2 */ + "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n" + "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n" + "vmla.f32 q9, q5, d1[0] @ w0 * inr02\n" + "vmla.f32 q10, q5, d2[0] @ w0 * inr04\n" + "vmla.f32 q11, q5, d3[0] @ w0 * inr06\n" + "vld1.32 {d4-d7}, [%[r2]]! @ load r2\n" + "vmla.f32 q8, q6, d0[1] @ w1 * inr01\n" + "vmla.f32 q9, q6, d1[1] @ w1 * inr03\n" + "vmla.f32 q10, q6, d2[1] @ w1 * inr05\n" + "vmla.f32 q11, q6, d3[1] @ w1 * inr07\n" + "vld1.32 {d9}, [%[r2]] @ load r2, 9th float\n" + "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n" + "vmla.f32 q9, q7, d2[0] @ w2 * inr04\n" + "vmla.f32 q10, q7, d3[0] @ w2 * inr06\n" + "vmla.f32 q11, q7, d8[0] @ w2 * inr08\n" + "sub %[r2], %[r2], #32 @ r2 - 32\n" + /* mul r2, with w0, w1, w2 */ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1\n" + "vmla.f32 q12, q5, d4[0] @ w0 * inr20\n" + "vmla.f32 q13, q5, d5[0] @ w0 * inr22\n" + "vmla.f32 q14, q5, d6[0] @ w0 * inr24\n" + "vmla.f32 q15, q5, d7[0] @ w0 * inr26\n" + "vld1.32 {d8}, [%[r1]] @ load r1, 9th float\n" + "vmla.f32 q12, q6, d4[1] @ w1 * inr21\n" + "vmla.f32 q13, q6, d5[1] @ w1 * inr23\n" + "vmla.f32 q14, q6, d6[1] @ w1 * inr25\n" + "vmla.f32 q15, q6, d7[1] @ w1 * inr27\n" + "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4, to q5, q6\n" + "vmla.f32 q12, q7, d5[0] @ w2 * inr22\n" + "vmla.f32 q13, q7, d6[0] @ w2 * inr24\n" + "vmla.f32 q14, q7, d7[0] @ w2 * inr26\n" + "vmla.f32 q15, q7, d9[0] @ w2 * inr28\n" + "vld1.32 {d14-d15}, [%[wc0]]! 
@ load w5, to q7\n"
+ /* mul r1, with w3, w4, w5 */
+ "vmla.f32 q8, q5, d0[0] @ w3 * inr10\n"
+ "vmla.f32 q9, q5, d1[0] @ w3 * inr12\n"
+ "vmla.f32 q10, q5, d2[0] @ w3 * inr14\n"
+ "vmla.f32 q11, q5, d3[0] @ w3 * inr16\n"
+ "vld1.32 {d4-d7}, [%[r3]]! @ load r3, 8 float\n"
+ "vmla.f32 q8, q6, d0[1] @ w4 * inr11\n"
+ "vmla.f32 q9, q6, d1[1] @ w4 * inr13\n"
+ "vmla.f32 q10, q6, d2[1] @ w4 * inr15\n"
+ "vmla.f32 q11, q6, d3[1] @ w4 * inr17\n"
+ "vld1.32 {d9}, [%[r3]] @ load r3, 9th float\n"
+ "vmla.f32 q8, q7, d1[0] @ w5 * inr12\n"
+ "vmla.f32 q9, q7, d2[0] @ w5 * inr14\n"
+ "vmla.f32 q10, q7, d3[0] @ w5 * inr16\n"
+ "vmla.f32 q11, q7, d8[0] @ w5 * inr18\n"
+ "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n"
+ /* mul r3, with w3, w4, w5 */
+ "vld1.32 {d0-d3}, [%[r2]]! @ load r2\n"
+ "vmla.f32 q12, q5, d4[0] @ w3 * inr30\n"
+ "vmla.f32 q13, q5, d5[0] @ w3 * inr32\n"
+ "vmla.f32 q14, q5, d6[0] @ w3 * inr34\n"
+ "vmla.f32 q15, q5, d7[0] @ w3 * inr36\n"
+ "vld1.32 {d8}, [%[r2]] @ load r2, 9th float\n"
+ "vmla.f32 q12, q6, d4[1] @ w4 * inr31\n"
+ "vmla.f32 q13, q6, d5[1] @ w4 * inr33\n"
+ "vmla.f32 q14, q6, d6[1] @ w4 * inr35\n"
+ "vmla.f32 q15, q6, d7[1] @ w4 * inr37\n"
+ "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n"
+ "vmla.f32 q12, q7, d5[0] @ w5 * inr32\n"
+ "vmla.f32 q13, q7, d6[0] @ w5 * inr34\n"
+ "vmla.f32 q14, q7, d7[0] @ w5 * inr36\n"
+ "vmla.f32 q15, q7, d9[0] @ w5 * inr38\n"
+ "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n"
+ /* mul r2, with w6, w7, w8 */
+ "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n"
+ "vmla.f32 q9, q5, d1[0] @ w6 * inr22\n"
+ "vmla.f32 q10, q5, d2[0] @ w6 * inr24\n"
+ "vmla.f32 q11, q5, d3[0] @ w6 * inr26\n"
+ "vld1.32 {d4-d7}, [%[r4]]! @ load r4\n"
+ "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n"
+ "vmla.f32 q9, q6, d1[1] @ w7 * inr23\n"
+ "vmla.f32 q10, q6, d2[1] @ w7 * inr25\n"
+ "vmla.f32 q11, q6, d3[1] @ w7 * inr27\n"
+ "vld1.32 {d9}, [%[r4]] @ load r4, 9th float\n"
+ "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n"
+ "vmla.f32 q9, q7, d2[0] @ w8 * inr24\n"
+ "vmla.f32 q10, q7, d3[0] @ w8 * inr26\n"
+ "vmla.f32 q11, q7, d8[0] @ w8 * inr28\n"
+ "sub %[wc0], %[wc0], #144 @ wc0 - 144\n"
+ /* mul r4, with w6, w7, w8 */
+ "vld1.32 {d0-d3}, [%[r0]]! @ load r0\n"
+ "vmla.f32 q12, q5, d4[0] @ w6 * inr40\n"
+ "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n"
+ "vmla.f32 q13, q5, d5[0] @ w6 * inr42\n"
+ "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n"
+ "vmla.f32 q14, q5, d6[0] @ w6 * inr44\n"
+ "vmla.f32 q15, q5, d7[0] @ w6 * inr46\n"
+ "vld1.32 {d8}, [%[r0]] @ load r0, 9th float\n"
+ "vmla.f32 q12, q6, d4[1] @ w7 * inr41\n"
+ "vmla.f32 q13, q6, d5[1] @ w7 * inr43\n"
+ "vmla.f32 q14, q6, d6[1] @ w7 * inr45\n"
+ "vmla.f32 q15, q6, d7[1] @ w7 * inr47\n"
+ "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
+ "vmla.f32 q12, q7, d5[0] @ w8 * inr42\n"
+ "vmla.f32 q13, q7, d6[0] @ w8 * inr44\n"
+ "vmla.f32 q14, q7, d7[0] @ w8 * inr46\n"
+ "vmla.f32 q15, q7, d9[0] @ w8 * inr48\n"
+ "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
+ "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n"
+ "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n"
+ "vld1.32 {d16-d19}, [%[ptr_out0]]! 
@ load outr0\n" + "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" + "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" + "subs %[cnt], #1 @ loop count--\n" + "bne 0b @ jump to main loop\n" + : [cnt] "+r"(cnt), + [r0] "+r"(r0),[r1] "+r"(r1), + [r2] "+r"(r2),[r3] "+r"(r3), + [r4] "+r"(r4), + [ptr_out0] "+r"(ptr_out0), + [ptr_out1] "+r"(ptr_out1), + [wc0] "+r"(wc0) + : + : "cc","memory","q0","q1","q2","q3","q4", + "q5","q6","q7","q8","q9","q10", + "q11","q12","q13","q14","q15" + ); + // clang-format on inr0 += win_round; inr1 += win_round; @@ -684,7 +499,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, write_to_output_c4_fp32(pre_out, dout_batch, c, - c + hout_c_block, + c + OUT_C_BLOCK, h, h + h_kernel, 0, @@ -721,7 +536,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, } fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); - for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { const float* wc0 = weight_c; const float* inr0 = block_inr0; @@ -755,158 +570,80 @@ void conv_3x3s2_direct_fp32(const float* i_data, const float* r4 = inr4; int cnt = w_loop; + // clang-format off asm volatile( - "ldr q21, [%[ptr_out0]] \n" /* load outr00, - outr01, - outr02, - outr03*/ - - "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ - "ldr d10, [%[r0]] \n" /* load input r0, 9th - element*/ - "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ - "ldr d12, [%[r2]] \n" /* load input r2, 9th - element*/ - "2: \n" /* main loop*/ + "ldr q21, [%[ptr_out0]]\n" /* load outr00-outr03*/ + "ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/ + "ldr d10, [%[r0]]\n"/* load input r0, 9th element*/ + "ld2 {v4.4s, v5.4s}, [%[r2]], #32\n" /* load input r2*/ + "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/ + "2:\n" /* main loop*/ /* r0, r2, mul w0, get out r0, r1 */ - "ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11, - outr12, outr13*/ - - "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2, - 4, 6]*/ - "fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2, - 4, 6]*/ - - "ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/ - + "ldr q22, [%[ptr_out1]]\n" /* load outr10 - outr13*/ + "fmla v21.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0*/ + "fmla v22.4s , %[w0].4s, v4.4s\n" /* outr1 = w0 * r2*/ + "ld2 {v2.4s, v3.4s}, [%[r1]], #32\n" /* load input r1*/ /* r2 mul w6, get out r0*/ - "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2, - 4, 6]*/ - "ldr d11, [%[r1]] \n" /* load input r1, 9th - element*/ - + "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr0 = w6 * r2*/ + "ldr d11, [%[r1]]\n" /* load input r1, 9th element*/ /* shift left 1 */ "ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/ "ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/ - /* r0, r2, mul w1, get out r0, r1 */ - "fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3, - 5, 7]*/ - "fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3, - 5, 7]*/ - - "ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/ - + "fmla v21.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0*/ + "fmla v22.4s , %[w1].4s, v5.4s\n" /* outr1 = w1 * r2*/ + "ld2 {v6.4s, v7.4s}, [%[r3]], #32\n" /* load input r3*/ /* r2 mul w7, get out r0 */ - "fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1, - 3, 5, 7]*/ - - "ldr d13, [%[r3]] \n" /* load input r3, 9th - element*/ - + "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr00 = w7 * r2*/ + "ldr d13, [%[r3]]\n" /* load input r3, 9th element*/ /* r0, r2, mul w2, get out r0, r1 */ - "fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * 
r0[2, 4,
- 6, 8]*/
- "fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4,
- 6, 8]*/
-
- "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/
-
+ "fmla v21.4s , %[w2].4s, v15.4s\n" /* outr0 = w2 * r0*/
+ "fmla v22.4s , %[w2].4s, v16.4s\n" /* outr1 = w2 * r2*/
+ "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/
 /* r2, mul w8, get out r0 */
- "fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2,
- 4, 6, 8]*/
-
- "ldr d14, [%[r4]] \n" /* load input r4, 9th
- element*/
-
+ "fmla v21.4s , %[w8].4s, v16.4s\n" /* outr00 = w8 * r2*/
+ "ldr d14, [%[r4]]\n" /* load input r4, 9th element*/
 /* r1, r3, mul w3, get out r0, r1 */
- "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2,
- 4, 6]*/
- "fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2,
- 4, 6]*/
-
+ "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr0 = w3 * r1*/
+ "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr1 = w3 * r3*/
 /* shift left 1 */
 "ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/
 "ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/
-
- "ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/
-
+ "ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/
 /* r1, r3, mul w4, get out r0, r1 */
- "fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3,
- 5, 7]*/
- "fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3,
- 5, 7]*/
-
- "ldr d10, [%[r0]] \n" /* load input r0, 9th
- element*/
-
+ "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr0 = w4 * r1*/
+ "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr1 = w4 * r3*/
+ "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
 /* r1, r3, mul w5, get out r0, r1 */
- "fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/
- "fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/
-
- "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
- "ldr d12, [%[r2]] \n" /* load input r2, 9th
- element*/
- "str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/
-
+ "fmla v21.4s , %[w5].4s, v15.4s\n" /* outr0 = w5 * r1[2]*/
+ "fmla v22.4s , %[w5].4s, v16.4s\n" /* outr1 = w5 * r3[2]*/
+ "ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
+ "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
+ "str q21, [%[ptr_out0]], #16\n" /* save outr0*/
 /* r4, mul w6, get out r1 */
- "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4[0, 2,
- 4, 6]*/
-
+ "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4*/
 "ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r4 1*/
 "ldr q21, [%[ptr_out0]] \n" /* load outr0*/
-
 /* r4, mul w7, get out r1 */
- "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3,
- 5, 7]*/
-
+ "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4*/
 /* r4, mul w8, get out r1 */
- "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4,
- 6, 8]*/
-
+ "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4*/
 "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
 "str q22, [%[ptr_out1]], #16 \n" /* save outr1*/
 "bne 2b \n" /* jump to main loop*/
-
 : [cnt] "+r"(cnt),
- [r0] "+r"(r0),
- [r1] "+r"(r1),
- [r2] "+r"(r2),
- [r3] "+r"(r3),
+ [r0] "+r"(r0),[r1] "+r"(r1),
+ [r2] "+r"(r2),[r3] "+r"(r3),
 [r4] "+r"(r4),
 [ptr_out0] "+r"(ptr_out0),
 [ptr_out1] "+r"(ptr_out1)
- : [w0] "w"(w0),
- [w1] "w"(w1),
- [w2] "w"(w2),
- [w3] "w"(w3),
- [w4] "w"(w4),
- [w5] "w"(w5),
- [w6] "w"(w6),
- [w7] "w"(w7),
- [w8] "w"(w8)
- : "cc",
- "memory",
- "v0",
- "v1",
- "v2",
- "v3",
- "v4",
- "v5",
- "v6",
- "v7",
- "v8",
- "v9",
- "v10",
- "v11",
- "v12",
- "v13",
- "v14",
- "v15",
- "v16",
- "v21",
- "v22");
-
+ : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
+ [w3] "w"(w3),[w4] 
"w"(w4),[w5] "w"(w5), + [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) + : "cc","memory","v0","v1","v2","v3", + "v4","v5","v6","v7","v8","v9","v10","v11", + "v12","v13","v14","v15","v16","v21","v22"); + // clang-format on wc0 += 36; inr0 += win_round; inr1 += win_round; @@ -944,184 +681,92 @@ void conv_3x3s2_direct_fp32(const float* i_data, int cnt = w_loop / 2; if (cnt > 0) { + // clang-format off asm volatile( - /* main loop */ - "0: @ " - "main loop\n" - "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " - "or01\n" - "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " - "or11\n" - "vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 " - "float, interleave\n" - "vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 " - "float, interleave\n" - "vld1.32 {d22}, [%[r2]] @ load 16th " - "float\n" - - /* r2 * w2, r2 * w0, get or0, or1 */ - "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, " - "1, 3, 5, 7\n" - "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, " - "9, 11, 13, 15\n" - "vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 " - "float, interleave\n" - "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, " - "1, 3, 5, 7\n" - "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, " - "9, 11, 13, 15\n" - - "vext.32 q4, q3, q5, #1 @ r2, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q6, q5, q11, #1 @ r2, shift " - "left 1, get 10, 12, 14, 16\n" - - "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " - "0, 2, 4, 6\n" - "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, " - "8, 10, 12, 14\n" - "vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 " - "float, interleave\n" - "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, " - "0, 2, 4, 6\n" - "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, " - "8, 10, 12, 14\n" - - "vld1.32 {d22}, [%[r0]] @ load 16th " - "float\n" - - "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, " - "2, 4, 6, 8\n" - "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, " - "2, 4, 6, 8\n" - "vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 " - "float, interleave\n" - "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, " - "10, 12, 14, 16\n" - "vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, " - "10, 12, 14, 16\n" - "vld2.32 {d10-d13}, [%[r3]]! @ load r3, 8 " - "float, interleave\n" - - /* r0 * w0, get or0, r3 * w1, get or1*/ - "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " - "1, 3, 5, 7\n" - "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, " - "9, 11, 13, 15\n" - "vext.32 q8, q7, q9, #1 @ r0, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q10, q9, q11, #1 @ r0, shift " - "left 1, get 10, 12, 14, 16\n" - "vld1.32 {d22}, [%[r3]] @ load 16th " - "float\n" - "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, " - "1, 3, 5, 7\n" - "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, " - "9, 11, 13, 15\n" - - "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, " - "0, 2, 4, 6\n" - "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, " - "8, 10, 12, 14\n" - "vext.32 q4, q3, q5, #1 @ r3, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q6, q5, q11, #1 @ r3, shift " - "left 1, get 10, 12, 14, 16\n" - "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, " - "0, 2, 4, 6\n" - "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, " - "8, 10, 12, 14\n" - - "vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, " - "2, 4, 6, 8\n" - "vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 " - "float, interleave\n" - "vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, " - "10, 12, 14, 16\n" - "vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 " - "float, interleave\n" - "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, " - "2, 4, 6, 8\n" - "vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 " - "float, interleave\n" - "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, " - "10, 12, 14, 16\n" - "vld2.32 {d10-d13}, [%[r4]]! 
@ load r4, 8 " - "float, interleave\n" - - "vld1.32 {d22}, [%[r1]] @ load 16th " - "float\n" - - /* r1 * w1, get or0, r4 * w2, get or1 */ - "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " - "1, 3, 5, 7\n" - "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, " - "9, 11, 13, 15\n" - "vext.32 q8, q7, q9, #1 @ r1, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q10, q9, q11, #1 @ r1, shift " - "left 1, get 10, 12, 14, 16\n" - "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, " - "1, 3, 5, 7\n" - "vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, " - "9, 11, 13, 15\n" - "vld1.32 {d22}, [%[r4]] @ load 16th " - "float\n" - - "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, " - "0, 2, 4, 6\n" - "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, " - "8, 10, 12, 14\n" - "vext.32 q4, q3, q5, #1 @ r1, shift " - "left 1, get 2, 4, 6, 8\n" - "vext.32 q6, q5, q11, #1 @ r1, shift " - "left 1, get 10, 12, 14, 16\n" - "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, " - "0, 2, 4, 6\n" - "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, " - "8, 10, 12, 14\n" - - "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, " - "2, 4, 6, 8\n" - "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, " - "10, 12, 14, 16\n" - "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, " - "2, 4, 6, 8\n" - "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, " - "10, 12, 14, 16\n" - - "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n" - "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n" - - "subs %[cnt], #1 @loop count " - "-1\n" - "bne 0b @ jump to " - "main loop\n" - - : [cnt] "+r"(cnt), - [r0] "+r"(r0), - [r1] "+r"(r1), - [r2] "+r"(r2), - [r3] "+r"(r3), - [r4] "+r"(r4), - [ptr_out0] "+r"(ptr_out0), - [ptr_out1] "+r"(ptr_out1) - : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + /* main loop */ + "0: @ main loop\n" + "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n" + "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, or11\n" + "vld2.32 {d6-d9}, [%[r2]]! @ load r2\n" + "vld2.32 {d10-d13}, [%[r2]]! @ load r2\n" + "vld1.32 {d22}, [%[r2]] @ load 16th float\n" + /* r2 * w2, r2 * w0, get or0, or1 */ + "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2\n" + "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2\n " + "vld2.32 {d14-d17}, [%[r0]]! @ load r0\n" + "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2\n" + "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2\n" + "vext.32 q4, q3, q5, #1 @ r2, shift left 1\n" + "vext.32 q6, q5, q11, #1 @ r2, shift left 1\n" + "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n" + "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2\n" + "vld2.32 {d18-d21}, [%[r0]]! @ load r0\n" + "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2\n" + "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2\n" + "vld1.32 {d22}, [%[r0]] @ load 16th float\n" + "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2\n" + "vld2.32 {d6-d9}, [%[r3]]! @ load r3\n" + "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2\n" + "vmla.f32 q15, q6, %f[w0][0] @ w02 * r2\n" + "vld2.32 {d10-d13}, [%[r3]]! 
@ load r3\n"
+ /* r0 * w0, get or0, r3 * w1, get or1*/
+ "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n"
+ "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0\n"
+ "vext.32 q8, q7, q9, #1 @ r0, shift left 1\n"
+ "vext.32 q10, q9, q11, #1 @ r0, shift left 1\n"
+ "vld1.32 {d22}, [%[r3]] @ load 16th float\n"
+ "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3\n"
+ "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3\n"
+ "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0\n"
+ "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0\n"
+ "vext.32 q4, q3, q5, #1 @ r3, shift left 1\n"
+ "vext.32 q6, q5, q11, #1 @ r3, shift left 1\n"
+ "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3\n"
+ "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3\n"
+ "vmla.f32 q12, q8, %f[w0][0] @ w02 * r0\n"
+ "vld2.32 {d14-d17}, [%[r1]]! @ load r1\n"
+ "vmla.f32 q13, q10, %f[w0][0] @ w02 * r0\n"
+ "vld2.32 {d18-d21}, [%[r1]]! @ load r1\n"
+ "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3\n"
+ "vld2.32 {d6-d9}, [%[r4]]! @ load r4\n"
+ "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3\n"
+ "vld2.32 {d10-d13}, [%[r4]]! @ load r4\n"
+ "vld1.32 {d22}, [%[r1]] @ load 16th float\n"
+ /* r1 * w1, get or0, r4 * w2, get or1 */
+ "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n"
+ "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1\n"
+ "vext.32 q8, q7, q9, #1 @ r1, shift left 1\n"
+ "vext.32 q10, q9, q11, #1 @ r1, shift left 1\n"
+ "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4\n"
+ "vmla.f32 q15, q6, %e[w2][1] @ w21 * r4\n"
+ "vld1.32 {d22}, [%[r4]] @ load 16th float\n"
+ "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1\n"
+ "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1\n"
+ "vext.32 q4, q3, q5, #1 @ r4, shift left 1\n"
+ "vext.32 q6, q5, q11, #1 @ r4, shift left 1\n"
+ "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4\n"
+ "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4\n"
+ "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1\n"
+ "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1\n"
+ "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4\n"
+ "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4\n"
+ "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n"
+ "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or1\n"
+ "subs %[cnt], #1 @ loop count -1\n"
+ "bne 0b @ jump to main loop\n"
+ : [cnt] "+r"(cnt),
+ [r0] "+r"(r0),[r1] "+r"(r1),[r2] "+r"(r2),
+ [r3] "+r"(r3),[r4] "+r"(r4),
+ [ptr_out0] "+r"(ptr_out0),
+ [ptr_out1] "+r"(ptr_out1)
+ : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
+ : "cc","memory","q3","q4",
+ "q5","q6","q7","q8","q9","q10",
+ "q11","q12","q13","q14","q15"
+ );
+ // clang-format on
 }
 //! 
deal with remain ow
      if (w_loop & 1) {
diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h
index c8a302881c..c5baa31e14 100644
--- a/lite/backends/arm/math/conv_impl.h
+++ b/lite/backends/arm/math/conv_impl.h
@@ -23,6 +23,9 @@ namespace lite {
 namespace arm {
 namespace math {
 
+/// conv 3x3s1
+size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
+                                       ARMContext* ctx);
 void conv_3x3s1_direct_fp32(const float* din,
                             float* dout,
                             int num,
@@ -53,6 +56,9 @@ void conv_3x3s1_direct_int8(const int8_t* din,
                             ARMContext* ctx,
                             const float* scale);
 
+/// conv3x3s2
+size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
+                                       ARMContext* ctx);
 void conv_3x3s2_direct_fp32(const float* din,
                             float* dout,
                             int num,
diff --git a/lite/core/device_info.cc b/lite/core/device_info.cc
index 078464ccf4..c150b2b177 100644
--- a/lite/core/device_info.cc
+++ b/lite/core/device_info.cc
@@ -1104,13 +1104,13 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) {
   SetCacheInfo(0, 1, l1size);
   SetCacheInfo(1, 1, l2size);
   SetCacheInfo(2, 1, l3size);
-  workspace_.Resize({2 * (l1size + l2size)});
+  workspace_.Resize({llc_size()});
+  workspace_.mutable_data<float>();
 }
 
-bool DeviceInfo::ExtendWorkspace(int size) {
+bool DeviceInfo::ExtendWorkspace(size_t size) {
   workspace_.Resize({size + llc_size()});
-  workspace_.mutable_data();
-  return true;
+  return workspace_.mutable_data<float>() != nullptr;
 }
 
 #endif  // LITE_WITH_ARM
diff --git a/lite/core/device_info.h b/lite/core/device_info.h
index 26954341e3..81c0ac4bf9 100644
--- a/lite/core/device_info.h
+++ b/lite/core/device_info.h
@@ -73,7 +73,7 @@ class DeviceInfo {
   T* workspace_data() {
     return reinterpret_cast<T*>(workspace_.mutable_data<float>());
   }
-  bool ExtendWorkspace(int size);
+  bool ExtendWorkspace(size_t size);
 
  private:
   int core_num_;
diff --git a/lite/kernels/arm/conv_direct.cc b/lite/kernels/arm/conv_direct.cc
index 1595d12b99..ae8c1d1b9a 100644
--- a/lite/kernels/arm/conv_direct.cc
+++ b/lite/kernels/arm/conv_direct.cc
@@ -19,6 +19,24 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
+template <>
+void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
+  auto& param = this->template Param<operators::ConvParam>();
+  auto x_dims = param.x->dims();
+  if (last_shape_ == x_dims) {
+    return;
+  }
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  if (param.strides[0] == 2) {
+    ctx.ExtendWorkspace(
+        lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx));
+  } else {
+    ctx.ExtendWorkspace(
+        lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx));
+  }
+  last_shape_ = x_dims;
+}
+
 template <>
 void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<operators::ConvParam>();
@@ -70,6 +88,9 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   }
 }
 
+template <>
+void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {}
+
 template <>
 void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<operators::ConvParam>();
@@ -126,6 +147,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   }
 }
 
+template <>
+void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::ReInitWhenNeeded() {}
+
 template <>
 void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
   auto& param = this->Param<operators::ConvParam>();
diff --git a/lite/kernels/arm/conv_direct.h b/lite/kernels/arm/conv_direct.h
index cd90c4d6c5..99025e2cec 100644
--- a/lite/kernels/arm/conv_direct.h
+++ b/lite/kernels/arm/conv_direct.h
@@ -178,10 +178,12 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
                  w_scale_);
   }
 
+  virtual void ReInitWhenNeeded();
   virtual void Run();
 
   /// todo, support inplace weights transform
  protected:
+  DDim last_shape_;
   Tensor weights_;
   Tensor bias_;
   bool flag_trans_weights_{false};
-- 
GitLab
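
A note on the register tiling behind the rewritten aarch64 stride-2 main loop: each iteration writes a 4-channel x 2-row x 4-column output tile in the c4 layout, which is exactly the eight 128-bit stp stores visible in the asm (q15-q18 for row 0, q19-q22 for row 1). A compile-time restatement of that accounting, as a sketch; the constant values mirror the OUT_*_BLOCK constants this patch introduces:

  // Sketch: per-iteration output footprint of the aarch64 s2 main loop.
  constexpr int OUT_C_BLOCK = 4;  // output channels packed per c4 block
  constexpr int OUT_H_BLOCK = 2;  // output rows per main-loop iteration
  constexpr int OUT_W_BLOCK = 4;  // output columns per main-loop iteration
  // Eight q-registers x 4 floats = 32 floats stored per iteration.
  static_assert(OUT_C_BLOCK * OUT_H_BLOCK * OUT_W_BLOCK == 8 * 4,
                "register tile must match the eight 128-bit stores");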
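
The runtime contract tying the pieces together: ReInitWhenNeeded() asks the math library for a byte count, DeviceInfo::ExtendWorkspace() grows the shared scratch tensor (always by size + llc_size(), apparently so a leading cache-sized region stays available to the kernels), and the kernel later pulls the raw pointer through workspace_data<T>(). Below is a minimal self-contained sketch of that flow; MockContext and PrepareWorkspace are illustrative stand-ins, not PaddleLite APIs:

  #include <cstddef>
  #include <iostream>
  #include <vector>

  // MockContext models only the three members this patch exercises on
  // ARMContext/DeviceInfo; a std::vector stands in for the workspace tensor.
  struct MockContext {
    std::vector<char> workspace_;
    size_t llc_size() const { return 512 * 1024; }  // bytes, assumed value

    // Mirrors the new contract: grow to size + llc_size() and report
    // whether the backing allocation succeeded.
    bool ExtendWorkspace(size_t size) {
      workspace_.resize(size + llc_size());
      return !workspace_.empty();
    }

    template <typename T>
    T* workspace_data() {
      return reinterpret_cast<T*>(workspace_.data());
    }
  };

  // Hypothetical kernel-side helper: reserve `bytes` of conv scratch and
  // skip past the reserved leading llc_size() region.
  float* PrepareWorkspace(MockContext* ctx, size_t bytes) {
    if (!ctx->ExtendWorkspace(bytes)) {
      return nullptr;  // allocation failed; caller can fall back to gemm
    }
    return ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
  }

  int main() {
    MockContext ctx;
    float* tmp = PrepareWorkspace(&ctx, 1u << 20);  // e.g. 1 MiB of scratch
    std::cout << (tmp != nullptr) << "\n";          // prints 1
  }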
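
Finally, ReInitWhenNeeded() is only cheap on steady input shapes because of the last_shape_ guard, and the guard only fires if the shape is recorded after every recompute. A minimal sketch of that memoization pattern; MockDim and MockKernel are stand-ins for lite::DDim and DirectConv:

  #include <cstdint>
  #include <iostream>
  #include <vector>

  using MockDim = std::vector<int64_t>;  // stand-in for lite::DDim

  struct MockKernel {
    MockDim last_shape_;
    int recomputes = 0;

    void ReInitWhenNeeded(const MockDim& x_dims) {
      if (last_shape_ == x_dims) return;  // fast path on steady shapes
      ++recomputes;                       // stands in for ExtendWorkspace(...)
      last_shape_ = x_dims;               // record, or the guard never fires
    }
  };

  int main() {
    MockKernel k;
    k.ReInitWhenNeeded({1, 3, 224, 224});  // first shape: recompute
    k.ReInitWhenNeeded({1, 3, 224, 224});  // unchanged: skipped
    k.ReInitWhenNeeded({1, 3, 112, 112});  // new shape: recompute
    std::cout << k.recomputes << "\n";     // prints 2
  }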