// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "lite/backends/arm/math/conv_block_utils.h" #include "lite/backends/arm/math/conv_impl.h" #include "lite/core/context.h" #include "lite/operators/op_params.h" #ifdef ARM_WITH_OMP #include #endif namespace paddle { namespace lite { namespace arm { namespace math { const int OUT_C_BLOCK = 4; const int OUT_H_BLOCK = 2; const int OUT_W_BLOCK = 4; size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param, ARMContext* ctx) { auto dim_in = param.x->dims(); auto dim_out = param.output->dims(); const int threads = ctx->threads(); int llc_size = ctx->llc_size() / sizeof(float); const int pad_w = param.paddings[1]; const int pad_h = param.paddings[0]; int ow = dim_out[3]; int oh = dim_out[2]; int ic = dim_in[1]; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round + 2; int hout_r_block = (llc_size - 2 * win_round * ic) / (win_round * ic + OUT_C_BLOCK * wout_round * threads); hout_r_block = hout_r_block > oh ? oh : hout_r_block; hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block; const int hin_r_block = hout_r_block + 2; int in_len = win_round * ic; int pre_in_size = hin_r_block * in_len; int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; return sizeof(float) * (pre_in_size + ctx->threads() * pre_out_size); } void conv_3x3s1_direct_fp32(const float* i_data, float* o_data, int bs, int oc, int oh, int ow, int ic, int ih, int win, const float* weights, const float* bias, const operators::ConvParam& param, ARMContext* ctx) { const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); const int pad_h = param.paddings[0]; const int pad_w = param.paddings[1]; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK); const int win_round = wout_round + 2; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; int hout_r_block = (l2_size - 2 * win_round * ic) / (win_round * ic + OUT_C_BLOCK * wout_round * threads); hout_r_block = hout_r_block > oh ? oh : hout_r_block; hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK; hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block; const int hin_r_block = hout_r_block + 2; float* tmp_work_space = ctx->workspace_data(); float ptr_zero[win_round]; // NOLINT memset(ptr_zero, 0, sizeof(float) * win_round); float ptr_write[wout_round]; // NOLINT int in_len = win_round * ic; int pre_in_size = hin_r_block * in_len; int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round; float* pre_din = tmp_work_space; int size_in_channel = win * ih; int size_out_channel = ow * oh; int w_stride = ic * 9; // kernel_w * kernel_h; int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h * int ws = -pad_w; int we = ws + win_round; int w_loop = wout_round / 4; int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK; int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK; int out_row_stride = OUT_C_BLOCK * wout_round; for (int n = 0; n < bs; ++n) { const float* din_batch = i_data + n * ic * size_in_channel; float* dout_batch = o_data + n * oc * size_out_channel; for (int h = 0; h < oh; h += hout_r_block) { int h_kernel = hout_r_block; if (h + hout_r_block > oh) { h_kernel = oh - h; } int hs = h - pad_h; int he = hs + h_kernel + 2; prepack_input_nxw( din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); #pragma omp parallel for num_threads(threads) for (int c = 0; c < oc - (OUT_C_BLOCK - 1); c += OUT_C_BLOCK) { #ifdef ARM_WITH_OMP float* pre_out = pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; #else float* pre_out = pre_din + pre_in_size; #endif const float* block_inr0 = pre_din; const float* block_inr1 = block_inr0 + in_len; const float* block_inr2 = block_inr1 + in_len; const float* block_inr3 = block_inr2 + in_len; const float* weight_c = weights + c * w_stride; const float* bias_ptr = ptr_zero; if (flag_bias) { bias_ptr = bias + c; } fill_packed_biasc4( pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel); for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { const float* wc0 = weight_c; const float* inr0 = block_inr0; const float* inr1 = block_inr1; const float* inr2 = block_inr2; const float* inr3 = block_inr3; float* pre_out0 = pre_out + hk * out_row_stride; float* pre_out1 = pre_out0 + out_row_stride; #ifdef __aarch64__ for (int i = 0; i < ic; ++i) { float* ptr_out0 = pre_out0; float* ptr_out1 = pre_out1; float32x4_t w0 = vld1q_f32(wc0); // w0, v23 float32x4_t w1 = vld1q_f32(wc0 + 4); // w1, v24 float32x4_t w2 = vld1q_f32(wc0 + 8); // w2, v25 float32x4_t w3 = vld1q_f32(wc0 + 12); // w3, v26 float32x4_t w4 = vld1q_f32(wc0 + 16); // w4, v27 float32x4_t w5 = vld1q_f32(wc0 + 20); // w5, v28 float32x4_t w6 = vld1q_f32(wc0 + 24); // w6, v29 float32x4_t w7 = vld1q_f32(wc0 + 28); // w7, v30 float32x4_t w8 = vld1q_f32(wc0 + 32); // w8, v31 const float* r0 = inr0; const float* r1 = inr1; const float* r2 = inr2; const float* r3 = inr3; int cnt = w_loop; // clang-format off asm volatile( "ldp q15, q16, [%[ptr_out0]]\n" /* load outr00,outr01*/ "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ "2: \n" /* main loop*/ /* r0, r1, mul w0, get out r0, r1 */ "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ /* r0, r1, mul w1, get out r0, r1 */ "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ /* r0, r1, mul w2, get out r0, r1 */ "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ "fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ /* r1, r2, mul w3, get out r0, r1 */ "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ "fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ "ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/ /* r1, r2, mul w4, get out r0, r1 */ "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ "fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ "fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ /* r1, r2, mul w5, get out r0, r1 */ "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ "fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ "fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ /* r2, r3, mul w6, get out r0, r1 */ "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ "ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ /* r2, r3, mul w7, get out r0, r1 */ "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ /* r2, r3, mul w8, get out r0, r1 */ "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ "bne 2b \n" /* jump to main loop*/ : [cnt] "+r"(cnt), [r0] "+r"(r0),[r1] "+r"(r1), [r2] "+r"(r2),[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1) : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) : "cc","memory","v0","v1","v2","v3", "v4","v5","v6","v7","v15","v16", "v17","v18","v19","v20","v21","v22" ); // clang-format on wc0 += 9 * OUT_C_BLOCK; inr0 += win_round; inr1 += win_round; inr2 += win_round; inr3 += win_round; } #else // not __aarch64__ for (int i = 0; i < ic; ++i) { const float* wc0 = weight_c + i * w_stride_chin; float* ptr_out0 = pre_out0; float* ptr_out1 = pre_out1; const float* r0 = inr0; const float* r1 = inr1; const float* r2 = inr2; const float* r3 = inr3; int cnt = w_loop; // clang-format off asm volatile( "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" /* load weights */ "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" /* load r0, r1 */ "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n" "vld1.32 {d2}, [%[r0]] @ load r0\n" "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" /* main loop */ "0: @ main loop\n" /* mul r0 with w0, w1, w2, get out r0 */ "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n" "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n" "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n" "vmla.f32 q9, q5, d0[1] @ w0 * inr01\n" "vmla.f32 q10, q5, d1[0] @ w0 * inr02\n" "vmla.f32 q11, q5, d1[1] @ w0 * inr03\n" "vld1.32 {d3-d4}, [%[r1]]! @ load r1\n" "vmla.f32 q8, q6, d0[1] @ w1 * inr01\n" "vmla.f32 q9, q6, d1[0] @ w1 * inr02\n" "vmla.f32 q10, q6, d1[1] @ w1 * inr03\n" "vmla.f32 q11, q6, d2[0] @ w1 * inr04\n" "vld1.32 {d5}, [%[r1]] @ load r0\n" "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n" "vmla.f32 q9, q7, d1[1] @ w2 * inr03\n" "vmla.f32 q10, q7, d2[0] @ w2 * inr04\n" "vmla.f32 q11, q7, d2[1] @ w2 * inr05\n" "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n" /* mul r1 with w0, w1, w2, get out r1 */ "vmla.f32 q12, q5, d3[0] @ w0 * inr10\n" "vmla.f32 q13, q5, d3[1] @ w0 * inr11\n" "vmla.f32 q14, q5, d4[0] @ w0 * inr12\n" "vmla.f32 q15, q5, d4[1] @ w0 * inr13\n" "vmla.f32 q12, q6, d3[1] @ w1 * inr11\n" "vmla.f32 q13, q6, d4[0] @ w1 * inr12\n" "vmla.f32 q14, q6, d4[1] @ w1 * inr13\n" "vmla.f32 q15, q6, d5[0] @ w1 * inr14\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4\n" "vmla.f32 q12, q7, d4[0] @ w2 * inr12\n" "vmla.f32 q13, q7, d4[1] @ w2 * inr13\n" "vmla.f32 q14, q7, d5[0] @ w2 * inr14\n" "vmla.f32 q15, q7, d5[1] @ w2 * inr15\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w5\n" /* mul r1 with w3, w4, w5, get out r0 */ "vmla.f32 q8, q5, d3[0] @ w3 * inr10\n" "vmla.f32 q9, q5, d3[1] @ w3 * inr11\n" "vmla.f32 q10, q5, d4[0] @ w3 * inr12\n" "vmla.f32 q11, q5, d4[1] @ w3 * inr13\n" "vld1.32 {d0-d1}, [%[r2]]! @ load r2\n" "vmla.f32 q8, q6, d3[1] @ w4 * inr11\n" "vmla.f32 q9, q6, d4[0] @ w4 * inr12\n" "vmla.f32 q10, q6, d4[1] @ w4 * inr13\n" "vmla.f32 q11, q6, d5[0] @ w4 * inr14\n" "vld1.32 {d2}, [%[r2]] @ load r2\n" "vmla.f32 q8, q7, d4[0] @ w5 * inr12\n" "vmla.f32 q9, q7, d4[1] @ w5 * inr13\n" "vmla.f32 q10, q7, d5[0] @ w5 * inr14\n" "vmla.f32 q11, q7, d5[1] @ w5 * inr15\n" /* mul r2 with w3, w4, w5, get out r1 */ "vmla.f32 q12, q5, d0[0] @ w3 * inr20\n" "vmla.f32 q13, q5, d0[1] @ w3 * inr21\n" "vmla.f32 q14, q5, d1[0] @ w3 * inr22\n" "vmla.f32 q15, q5, d1[1] @ w3 * inr23\n" "vmla.f32 q12, q6, d0[1] @ w4 * inr21\n" "vmla.f32 q13, q6, d1[0] @ w4 * inr22\n" "vmla.f32 q14, q6, d1[1] @ w4 * inr23\n" "vmla.f32 q15, q6, d2[0] @ w4 * inr24\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n" "vmla.f32 q12, q7, d1[0] @ w5 * inr22\n" "vmla.f32 q13, q7, d1[1] @ w5 * inr23\n" "vmla.f32 q14, q7, d2[0] @ w5 * inr24\n" "vmla.f32 q15, q7, d2[1] @ w5 * inr25\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n" "sub %[wc0], %[wc0], #144 @ wc0 - 144\n" /* mul r2 with w6, w7, w8, get out r0 */ "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n" "vmla.f32 q9, q5, d0[1] @ w6 * inr21\n" "vld1.32 {d3-d4}, [%[r3]]! @ load r3\n" "vmla.f32 q10, q5, d1[0] @ w6 * inr22\n" "vmla.f32 q11, q5, d1[1] @ w6 * inr23\n" "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n" "vmla.f32 q9, q6, d1[0] @ w7 * inr22\n" "vld1.32 {d5}, [%[r3]] @ load r3\n" "vmla.f32 q10, q6, d1[1] @ w7 * inr23\n" "vmla.f32 q11, q6, d2[0] @ w7 * inr24\n" "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n" "vmla.f32 q9, q7, d1[1] @ w8 * inr23\n" "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n" "vmla.f32 q10, q7, d2[0] @ w8 * inr24\n" "vmla.f32 q11, q7, d2[1] @ w8 * inr25\n" "vld1.32 {d2}, [%[r0]] @ load r0\n" /* mul r3 with w6, w7, w8, get out r1 */ "vmla.f32 q12, q5, d3[0] @ w6 * inr20\n" "vmla.f32 q13, q5, d3[1] @ w6 * inr21\n" "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n" "vmla.f32 q14, q5, d4[0] @ w6 * inr22\n" "vmla.f32 q15, q5, d4[1] @ w6 * inr23\n" "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n" "vmla.f32 q12, q6, d3[1] @ w7 * inr21\n" "vmla.f32 q13, q6, d4[0] @ w7 * inr22\n" "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n" "vmla.f32 q14, q6, d4[1] @ w7 * inr23\n" "vmla.f32 q15, q6, d5[0] @ w7 * inr24\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n" "vmla.f32 q12, q7, d4[0] @ w8 * inr22\n" "vmla.f32 q13, q7, d4[1] @ w8 * inr23\n" "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n" "vmla.f32 q14, q7, d5[0] @ w8 * inr24\n" "vmla.f32 q15, q7, d5[1] @ w8 * inr25\n" "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n" "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n" "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n" "subs %[cnt], #1 @ loop count--\n" "bne 0b @ jump to main loop\n" : [cnt] "+r"(cnt), [r0] "+r"(r0),[r1] "+r"(r1), [r2] "+r"(r2),[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1), [wc0] "+r"(wc0) : : "cc","memory","q0","q1","q2","q3", "q4","q5","q6","q7","q8","q9", "q10","q11","q12","q13","q14","q15"); // clang-format on inr0 += win_round; inr1 += win_round; inr2 += win_round; inr3 += win_round; } #endif // __aarch64__ block_inr0 = block_inr2; block_inr1 = block_inr3; block_inr2 = block_inr1 + in_len; block_inr3 = block_inr2 + in_len; } write_to_output_c4_fp32(pre_out, dout_batch, c, c + OUT_C_BLOCK, h, h + h_kernel, 0, wout_round, oc, oh, ow, flag_relu, ptr_write); } const float* weight_remain_ptr = weights + c_round_down * w_stride; #pragma omp parallel for num_threads(threads) for (int c = 0; c < c_remain; ++c) { #ifdef ARM_WITH_OMP float* pre_out = pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; #else float* pre_out = pre_din + pre_in_size; #endif int c_idx = c_round_down + c; int h_kernel = hout_r_block; if (h + hout_r_block > oh) { h_kernel = oh - h; } const float* block_inr0 = pre_din; const float* block_inr1 = block_inr0 + in_len; const float* block_inr2 = block_inr1 + in_len; const float* block_inr3 = block_inr2 + in_len; const float* bias_ptr = ptr_zero; if (flag_bias) { bias_ptr = bias + c_idx; } fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) { const float* wc0 = weight_remain_ptr; const float* inr0 = block_inr0; const float* inr1 = block_inr1; const float* inr2 = block_inr2; const float* inr3 = block_inr3; float* pre_out0 = pre_out + hk * wout_round; float* pre_out1 = pre_out0 + wout_round; #ifdef __aarch64__ for (int i = 0; i < ic; ++i) { float* ptr_out0 = pre_out0; float* ptr_out1 = pre_out1; float32x4_t w0 = vdupq_n_f32(wc0[c]); // w0, v23 float32x4_t w1 = vdupq_n_f32(wc0[4 + c]); // w1, v24 float32x4_t w2 = vdupq_n_f32(wc0[8 + c]); // w2, v25 float32x4_t w3 = vdupq_n_f32(wc0[12 + c]); // w3, v26 float32x4_t w4 = vdupq_n_f32(wc0[16 + c]); // w4, v27 float32x4_t w5 = vdupq_n_f32(wc0[20 + c]); // w5, v28 float32x4_t w6 = vdupq_n_f32(wc0[24 + c]); // w6, v29 float32x4_t w7 = vdupq_n_f32(wc0[28 + c]); // w7, v30 float32x4_t w8 = vdupq_n_f32(wc0[32 + c]); // w8, v31 const float* r0 = inr0; const float* r1 = inr1; const float* r2 = inr2; const float* r3 = inr3; int cnt = w_loop; // clang-format off asm volatile( "ldr q21, [%[ptr_out0]]\n" /* load outr0, w0~w3*/ "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ "2: \n" /* main loop*/ "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ "fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ "ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ "fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/ "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ "fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ "fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ "ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ "fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ "fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ "fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ "fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ "str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ "ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/ "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ "bne 2b \n" /* jump to main loop*/ : [cnt] "+r"(cnt), [r0] "+r"(r0),[r1] "+r"(r1), [r2] "+r"(r2),[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1) : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2), [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5), [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8) : "cc","memory","v0", "v1","v2","v3","v4","v5","v6", "v7","v8","v9","v10","v11","v12", "v13","v14","v15","v21","v22" ); // clang-format on wc0 += 9 * OUT_C_BLOCK; inr0 += win_round; inr1 += win_round; inr2 += win_round; inr3 += win_round; } #else // not __aarch64__ for (int i = 0; i < ic; ++i) { float* ptr_out0 = pre_out0; float* ptr_out1 = pre_out1; //! get valid weights of current output channel float w_tmp[10] = {wc0[c], wc0[c + 4], wc0[c + 8], wc0[c + 12], wc0[c + 16], wc0[c + 20], wc0[c + 24], wc0[c + 28], wc0[c + 32], 0.f}; float32x4_t w0 = vld1q_f32(w_tmp); // w0, w1, w2, q0 float32x4_t w1 = vld1q_f32(w_tmp + 3); // w3, w4, w5, q1 float32x4_t w2 = vld1q_f32(w_tmp + 6); // w6, w7, w8, q2 const float* r0 = inr0; const float* r1 = inr1; const float* r2 = inr2; const float* r3 = inr3; int cnt = w_loop / 2; if (cnt > 0) { // clang-format off asm volatile( "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n" "vld1.32 {d6-d9}, [%[r0]]! @ load r0\n" "vld1.32 {d10}, [%[r0]] @ load r0\n" /* main loop */ "0: @ main loop\n" /* r0 * w0, w1, w2, get out r0*/ "vld1.32 {d28-d31}, [%[ptr_out1]]@ load or10 or11\n" "vext.32 q8, q3, q4, #1 @ r0, shift left 1\n" "vext.32 q9, q4, q5, #1 @ r0, shift left 1\n" "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0\n" "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0\n" "vext.32 q10, q3, q4, #2 @ r0, shift left 2\n" "vext.32 q11, q4, q5, #2 @ r0, shift left 2\n" "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n" "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0\n" "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8\n" "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0\n" "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0\n" "vld1.32 {d10}, [%[r1]] @ load r1\n" /* r1 * w3, w4, w5, get out r0*/ /* r1 * w0, w1, w2, get out r1*/ "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1\n" "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1\n" "vext.32 q8, q3, q4, #1 @ r1, shift left 1\n" "vext.32 q9, q4, q5, #1 @ r1, shift left 1\n" "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1\n" "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1\n" "vext.32 q10, q3, q4, #2 @ r1, shift left 2\n" "vext.32 q11, q4, q5, #2 @ r1, shift left 2\n" "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n" "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1\n" "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1\n" "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1\n" "vld1.32 {d6-d9}, [%[r2]]! @ load r2\n" "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1\n" "vmla.f32 q13, q11, %f[w1][0] @ w12 * r1\n" "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1\n" "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1\n" "vld1.32 {d10}, [%[r2]] @ load r2\n" /* r2 * w6, w7, w8, get out r0*/ /* r2 * w3, w4, w5, get out r1*/ "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n" "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2\n" "vext.32 q8, q3, q4, #1 @ r2, shift left 1\n" "vext.32 q9, q4, q5, #1 @ r2, shift left 1\n" "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2\n" "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2\n" "vext.32 q10, q3, q4, #2 @ r2, shift left 2\n" "vext.32 q11, q4, q5, #2 @ r2, shift left 2\n" "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2\n" "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2\n" "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2\n" "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2\n" "vld1.32 {d6-d9}, [%[r3]]! @ load r3\n" "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2\n" "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2\n" "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2\n" "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2\n" "vld1.32 {d10}, [%[r3]] @ load r3\n" /* r3 * w6, w7, w8, get out r1*/ "vext.32 q8, q3, q4, #1 @ r3, shift left 1\n" "vext.32 q9, q4, q5, #1 @ r3, shift left 1\n" "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3\n" "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3\n" "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, or01\n" "vext.32 q10, q3, q4, #2 @ r3, shift left 2\n" "vext.32 q11, q4, q5, #2 @ r3, shift left 2\n" "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3\n" "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3\n" "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00,or01\n" "vld1.32 {d6-d9}, [%[r0]]! @ load r3\n" "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3\n" "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3\n" "vld1.32 {d10}, [%[r0]] @ load r0\n" "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, or11\n" "subs %[cnt], #1 @ loop count -1\n" "bne 0b @ jump to main loop\n" : [cnt] "+r"(cnt), [r0] "+r"(r0),[r1] "+r"(r1), [r2] "+r"(r2),[r3] "+r"(r3), [ptr_out0] "+r"(ptr_out0), [ptr_out1] "+r"(ptr_out1) : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2) : "cc","memory","q3","q4", "q5","q6","q7","q8","q9","q10", "q11","q12","q13","q14","q15" ); // clang-format on r0 -= 8; } //! deal with remain ow if (w_loop & 1) { ptr_out0[0] += r0[0] * w_tmp[0] + r0[1] * w_tmp[1] + r0[2] * w_tmp[2] + r1[0] * w_tmp[3] + r1[1] * w_tmp[4] + r1[2] * w_tmp[5] + r2[0] * w_tmp[6] + r2[1] * w_tmp[7] + r2[2] * w_tmp[8]; ptr_out0[1] += r0[1] * w_tmp[0] + r0[2] * w_tmp[1] + r0[3] * w_tmp[2] + r1[1] * w_tmp[3] + r1[2] * w_tmp[4] + r1[3] * w_tmp[5] + r2[1] * w_tmp[6] + r2[2] * w_tmp[7] + r2[3] * w_tmp[8]; ptr_out0[2] += r0[2] * w_tmp[0] + r0[3] * w_tmp[1] + r0[4] * w_tmp[2] + r1[2] * w_tmp[3] + r1[3] * w_tmp[4] + r1[4] * w_tmp[5] + r2[2] * w_tmp[6] + r2[3] * w_tmp[7] + r2[4] * w_tmp[8]; ptr_out0[3] += r0[3] * w_tmp[0] + r0[4] * w_tmp[1] + r0[5] * w_tmp[2] + r1[3] * w_tmp[3] + r1[4] * w_tmp[4] + r1[5] * w_tmp[5] + r2[3] * w_tmp[6] + r2[4] * w_tmp[7] + r2[5] * w_tmp[8]; ptr_out1[0] += r1[0] * w_tmp[0] + r1[1] * w_tmp[1] + r1[2] * w_tmp[2] + r2[0] * w_tmp[3] + r2[1] * w_tmp[4] + r2[2] * w_tmp[5] + r3[0] * w_tmp[6] + r3[1] * w_tmp[7] + r3[2] * w_tmp[8]; ptr_out1[1] += r1[1] * w_tmp[0] + r1[2] * w_tmp[1] + r1[3] * w_tmp[2] + r2[1] * w_tmp[3] + r2[2] * w_tmp[4] + r2[3] * w_tmp[5] + r3[1] * w_tmp[6] + r3[2] * w_tmp[7] + r3[3] * w_tmp[8]; ptr_out1[2] += r1[2] * w_tmp[0] + r1[3] * w_tmp[1] + r1[4] * w_tmp[2] + r2[2] * w_tmp[3] + r2[3] * w_tmp[4] + r2[4] * w_tmp[5] + r3[2] * w_tmp[6] + r3[3] * w_tmp[7] + r3[4] * w_tmp[8]; ptr_out1[3] += r1[3] * w_tmp[0] + r1[4] * w_tmp[1] + r1[5] * w_tmp[2] + r2[3] * w_tmp[3] + r2[4] * w_tmp[4] + r2[5] * w_tmp[5] + r3[3] * w_tmp[6] + r3[4] * w_tmp[7] + r3[5] * w_tmp[8]; } wc0 += 36; inr0 += win_round; inr1 += win_round; inr2 += win_round; inr3 += win_round; } #endif // __aarch64__ block_inr0 = block_inr2; block_inr1 = block_inr3; block_inr2 = block_inr1 + in_len; block_inr3 = block_inr2 + in_len; } write_to_output_c1_fp32(pre_out, dout_batch, c_idx, c_idx + 1, h, h + h_kernel, 0, wout_round, oc, oh, ow, flag_relu, ptr_write); } } } } } // namespace math } // namespace arm } // namespace lite } // namespace paddle