提交 61647c35 编写于 作者: X Xiaoyang LI 提交者: Yan Chunwei

add workspace compute funcs for direct conv, test=develop (#2132)

上级 e122b4be
......@@ -26,6 +26,39 @@ namespace lite {
namespace arm {
namespace math {
// Blocking factors used by the 3x3 stride-1 direct convolution and by its
// workspace-size helper; both must use the same values.
// Number of output channels packed together in one output tile.
const int OUT_C_BLOCK = 4;
// Output-row granularity: the row-block height is kept a multiple of this.
const int OUT_H_BLOCK = 2;
// Output-column granularity: the output width is rounded up to a multiple of
// this.
const int OUT_W_BLOCK = 4;
/// Returns the scratch-buffer size, in bytes, required by the 3x3 stride-1
/// direct convolution for the given parameters and context.
///
/// The buffer layout is one pre-packed input tile followed by one packed
/// output tile per thread. The row-block height is derived from the
/// last-level cache size with the same formula conv_3x3s1_direct_fp32 uses,
/// so the two must be kept in sync.
///
/// @param param convolution parameters; only the input/output dims are read.
/// @param ctx   ARM context providing the thread count and cache size.
/// @return workspace size in bytes.
size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
                                       ARMContext* ctx) {
  auto dim_in = param.x->dims();
  auto dim_out = param.output->dims();
  const int threads = ctx->threads();
  // Cache capacity expressed in floats rather than bytes.
  const int llc_size = ctx->llc_size() / sizeof(float);

  const int ow = dim_out[3];
  const int oh = dim_out[2];
  const int ic = dim_in[1];

  const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
  // A 3-wide kernel at stride 1 needs two extra input columns per row.
  const int win_round = wout_round + 2;

  // Largest number of output rows whose input and output tiles fit in cache,
  // capped at the output height, rounded down to a multiple of OUT_H_BLOCK and
  // never smaller than one OUT_H_BLOCK.
  int hout_r_block = (llc_size - 2 * win_round * ic) /
                     (win_round * ic + OUT_C_BLOCK * wout_round * threads);
  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
  hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
  hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
  // A 3-tall kernel at stride 1 needs two extra input rows per block.
  const int hin_r_block = hout_r_block + 2;

  // Accumulate element counts in size_t so that large shapes cannot overflow
  // int before the conversion to bytes.
  const size_t in_len = static_cast<size_t>(win_round) * ic;
  const size_t pre_in_size = static_cast<size_t>(hin_r_block) * in_len;
  const size_t pre_out_size =
      static_cast<size_t>(OUT_C_BLOCK) * hout_r_block * wout_round;

  return sizeof(float) *
         (pre_in_size + static_cast<size_t>(threads) * pre_out_size);
}
void conv_3x3s1_direct_fp32(const float* i_data,
float* o_data,
int bs,
......@@ -44,19 +77,16 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1];
const int hout_c_block = 4;
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block;
const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
const int win_round = wout_round + 2;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
int hout_r_block = (l2_size - 2 * win_round * ic) /
(win_round * ic + hout_c_block * wout_round * threads);
(win_round * ic + OUT_C_BLOCK * wout_round * threads);
hout_r_block = hout_r_block > oh ? oh : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
const int hin_r_block = hout_r_block + 2;
......@@ -67,23 +97,23 @@ void conv_3x3s1_direct_fp32(const float* i_data,
int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = hout_c_block * hout_r_block * wout_round;
int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
float* pre_din = tmp_work_space;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int w_stride = ic * 9; // kernel_w * kernel_h;
int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h *
int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h *
int ws = -pad_w;
int we = ws + win_round;
int w_loop = wout_round / 4;
int c_remain = oc - (oc / hout_c_block) * hout_c_block;
int c_round_down = (oc / hout_c_block) * hout_c_block;
int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int out_row_stride = hout_c_block * wout_round;
int out_row_stride = OUT_C_BLOCK * wout_round;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
......@@ -97,7 +127,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
prepack_input_nxw(
din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero);
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc - (hout_c_block - 1); c += hout_c_block) {
for (int c = 0; c < oc - (OUT_C_BLOCK - 1); c += OUT_C_BLOCK) {
#ifdef ARM_WITH_OMP
float* pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
......@@ -115,9 +145,9 @@ void conv_3x3s1_direct_fp32(const float* i_data,
bias_ptr = bias + c;
}
fill_packed_biasc4(
pre_out, bias_ptr, wout_round * hout_c_block * h_kernel);
pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_c;
const float* inr0 = block_inr0;
......@@ -148,9 +178,9 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3;
int cnt = w_loop;
// clang-format off
asm volatile(
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00,
outr01*/
"ldp q15, q16, [%[ptr_out0]]\n" /* load outr00,outr01*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/
......@@ -166,7 +196,6 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/
"fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/
"fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/
/* r0, r1, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/
"fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/
......@@ -176,9 +205,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/
"fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/
"fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
/* r0, r1, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/
......@@ -188,7 +215,6 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/
"fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/
"fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/
/* r1, r2, mul w3, get out r0, r1 */
"fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/
"fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/
......@@ -198,9 +224,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/
"fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/
"fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/
"ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/
/* r1, r2, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/
"fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/
......@@ -210,9 +234,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/
"fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/
"fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
/* r1, r2, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/
"fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/
......@@ -222,7 +244,6 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/
"fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/
"fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/
/* r2, r3, mul w6, get out r0, r1 */
"fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/
......@@ -232,9 +253,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/
"fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/
"fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/
"ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/
/* r2, r3, mul w7, get out r0, r1 */
"fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/
......@@ -244,15 +263,12 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/
"fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/
"fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
/* r2, r3, mul w8, get out r0, r1 */
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/
"fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/
"fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/
"fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/
"stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/
"fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/
"stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/
......@@ -266,43 +282,21 @@ void conv_3x3s1_direct_fp32(const float* i_data,
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
wc0 += 9 * hout_c_block;
: [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
[w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5),
[w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8)
: "cc","memory","v0","v1","v2","v3",
"v4","v5","v6","v7","v15","v16",
"v17","v18","v19","v20","v21","v22"
);
// clang-format on
wc0 += 9 * OUT_C_BLOCK;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
......@@ -321,273 +315,135 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3;
int cnt = w_loop;
// clang-format off
asm volatile(
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ "
"load outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
/* load weights */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
/* load r0, r1 */
"vld1.32 {d0-d1}, [%[r0]]! @ load r0, "
"4 float\n"
"vld1.32 {d2}, [%[r0]] @ load r0, "
"2 float\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"vld1.32 {d0-d1}, [%[r0]]! @ load r0\n"
"vld1.32 {d2}, [%[r0]] @ load r0\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
/* main loop */
"0: @ main "
"loop\n"
"0: @ main loop\n"
/* mul r0 with w0, w1, w2, get out r0 */
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load "
"outr1, w0, w1, c0~c3\n"
"vmla.f32 q8, q5, d0[0] @ w0 * "
"inr00\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load "
"outr1, w2, w3, c0~c3\n"
"vmla.f32 q9, q5, d0[1] @ w0 * "
"inr01\n"
"vmla.f32 q10, q5, d1[0] @ w0 * "
"inr02\n"
"vmla.f32 q11, q5, d1[1] @ w0 * "
"inr03\n"
"vld1.32 {d3-d4}, [%[r1]]! @ load r1, "
"4 float\n"
"vmla.f32 q8, q6, d0[1] @ w1 * "
"inr01\n"
"vmla.f32 q9, q6, d1[0] @ w1 * "
"inr02\n"
"vmla.f32 q10, q6, d1[1] @ w1 * "
"inr03\n"
"vmla.f32 q11, q6, d2[0] @ w1 * "
"inr04\n"
"vld1.32 {d5}, [%[r1]] @ load r0, "
"2 float\n"
"vmla.f32 q8, q7, d1[0] @ w2 * "
"inr02\n"
"vmla.f32 q9, q7, d1[1] @ w2 * "
"inr03\n"
"vmla.f32 q10, q7, d2[0] @ w2 * "
"inr04\n"
"vmla.f32 q11, q7, d2[1] @ w2 * "
"inr05\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 "
"- 32, to start address\n"
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n"
"vmla.f32 q8, q5, d0[0] @ w0 * inr00\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n"
"vmla.f32 q9, q5, d0[1] @ w0 * inr01\n"
"vmla.f32 q10, q5, d1[0] @ w0 * inr02\n"
"vmla.f32 q11, q5, d1[1] @ w0 * inr03\n"
"vld1.32 {d3-d4}, [%[r1]]! @ load r1\n"
"vmla.f32 q8, q6, d0[1] @ w1 * inr01\n"
"vmla.f32 q9, q6, d1[0] @ w1 * inr02\n"
"vmla.f32 q10, q6, d1[1] @ w1 * inr03\n"
"vmla.f32 q11, q6, d2[0] @ w1 * inr04\n"
"vld1.32 {d5}, [%[r1]] @ load r0\n"
"vmla.f32 q8, q7, d1[0] @ w2 * inr02\n"
"vmla.f32 q9, q7, d1[1] @ w2 * inr03\n"
"vmla.f32 q10, q7, d2[0] @ w2 * inr04\n"
"vmla.f32 q11, q7, d2[1] @ w2 * inr05\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n"
/* mul r1 with w0, w1, w2, get out r1 */
"vmla.f32 q12, q5, d3[0] @ w0 * "
"inr10\n"
"vmla.f32 q13, q5, d3[1] @ w0 * "
"inr11\n"
"vmla.f32 q14, q5, d4[0] @ w0 * "
"inr12\n"
"vmla.f32 q15, q5, d4[1] @ w0 * "
"inr13\n"
"vmla.f32 q12, q6, d3[1] @ w1 * "
"inr11\n"
"vmla.f32 q13, q6, d4[0] @ w1 * "
"inr12\n"
"vmla.f32 q14, q6, d4[1] @ w1 * "
"inr13\n"
"vmla.f32 q15, q6, d5[0] @ w1 * "
"inr14\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, "
"w4, to q5, q6\n"
"vmla.f32 q12, q7, d4[0] @ w2 * "
"inr12\n"
"vmla.f32 q13, q7, d4[1] @ w2 * "
"inr13\n"
"vmla.f32 q14, q7, d5[0] @ w2 * "
"inr14\n"
"vmla.f32 q15, q7, d5[1] @ w2 * "
"inr15\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, "
"to q7\n"
"vmla.f32 q12, q5, d3[0] @ w0 * inr10\n"
"vmla.f32 q13, q5, d3[1] @ w0 * inr11\n"
"vmla.f32 q14, q5, d4[0] @ w0 * inr12\n"
"vmla.f32 q15, q5, d4[1] @ w0 * inr13\n"
"vmla.f32 q12, q6, d3[1] @ w1 * inr11\n"
"vmla.f32 q13, q6, d4[0] @ w1 * inr12\n"
"vmla.f32 q14, q6, d4[1] @ w1 * inr13\n"
"vmla.f32 q15, q6, d5[0] @ w1 * inr14\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4\n"
"vmla.f32 q12, q7, d4[0] @ w2 * inr12\n"
"vmla.f32 q13, q7, d4[1] @ w2 * inr13\n"
"vmla.f32 q14, q7, d5[0] @ w2 * inr14\n"
"vmla.f32 q15, q7, d5[1] @ w2 * inr15\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5\n"
/* mul r1 with w3, w4, w5, get out r0 */
"vmla.f32 q8, q5, d3[0] @ w3 * "
"inr10\n"
"vmla.f32 q9, q5, d3[1] @ w3 * "
"inr11\n"
"vmla.f32 q10, q5, d4[0] @ w3 * "
"inr12\n"
"vmla.f32 q11, q5, d4[1] @ w3 * "
"inr13\n"
"vld1.32 {d0-d1}, [%[r2]]! @ load r2, "
"4 float\n"
"vmla.f32 q8, q6, d3[1] @ w4 * "
"inr11\n"
"vmla.f32 q9, q6, d4[0] @ w4 * "
"inr12\n"
"vmla.f32 q10, q6, d4[1] @ w4 * "
"inr13\n"
"vmla.f32 q11, q6, d5[0] @ w4 * "
"inr14\n"
"vld1.32 {d2}, [%[r2]] @ load r2, "
"2 float\n"
"vmla.f32 q8, q7, d4[0] @ w5 * "
"inr12\n"
"vmla.f32 q9, q7, d4[1] @ w5 * "
"inr13\n"
"vmla.f32 q10, q7, d5[0] @ w5 * "
"inr14\n"
"vmla.f32 q11, q7, d5[1] @ w5 * "
"inr15\n"
"vmla.f32 q8, q5, d3[0] @ w3 * inr10\n"
"vmla.f32 q9, q5, d3[1] @ w3 * inr11\n"
"vmla.f32 q10, q5, d4[0] @ w3 * inr12\n"
"vmla.f32 q11, q5, d4[1] @ w3 * inr13\n"
"vld1.32 {d0-d1}, [%[r2]]! @ load r2\n"
"vmla.f32 q8, q6, d3[1] @ w4 * inr11\n"
"vmla.f32 q9, q6, d4[0] @ w4 * inr12\n"
"vmla.f32 q10, q6, d4[1] @ w4 * inr13\n"
"vmla.f32 q11, q6, d5[0] @ w4 * inr14\n"
"vld1.32 {d2}, [%[r2]] @ load r2\n"
"vmla.f32 q8, q7, d4[0] @ w5 * inr12\n"
"vmla.f32 q9, q7, d4[1] @ w5 * inr13\n"
"vmla.f32 q10, q7, d5[0] @ w5 * inr14\n"
"vmla.f32 q11, q7, d5[1] @ w5 * inr15\n"
/* mul r2 with w3, w4, w5, get out r1 */
"vmla.f32 q12, q5, d0[0] @ w3 * "
"inr20\n"
"vmla.f32 q13, q5, d0[1] @ w3 * "
"inr21\n"
"vmla.f32 q14, q5, d1[0] @ w3 * "
"inr22\n"
"vmla.f32 q15, q5, d1[1] @ w3 * "
"inr23\n"
"vmla.f32 q12, q6, d0[1] @ w4 * "
"inr21\n"
"vmla.f32 q13, q6, d1[0] @ w4 * "
"inr22\n"
"vmla.f32 q14, q6, d1[1] @ w4 * "
"inr23\n"
"vmla.f32 q15, q6, d2[0] @ w4 * "
"inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
"w7, to q5, q6\n"
"vmla.f32 q12, q7, d1[0] @ w5 * "
"inr22\n"
"vmla.f32 q13, q7, d1[1] @ w5 * "
"inr23\n"
"vmla.f32 q14, q7, d2[0] @ w5 * "
"inr24\n"
"vmla.f32 q15, q7, d2[1] @ w5 * "
"inr25\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
"to q7\n"
"sub %[wc0], %[wc0], #144 @ wc0 - "
"144 to start address\n"
"vmla.f32 q12, q5, d0[0] @ w3 * inr20\n"
"vmla.f32 q13, q5, d0[1] @ w3 * inr21\n"
"vmla.f32 q14, q5, d1[0] @ w3 * inr22\n"
"vmla.f32 q15, q5, d1[1] @ w3 * inr23\n"
"vmla.f32 q12, q6, d0[1] @ w4 * inr21\n"
"vmla.f32 q13, q6, d1[0] @ w4 * inr22\n"
"vmla.f32 q14, q6, d1[1] @ w4 * inr23\n"
"vmla.f32 q15, q6, d2[0] @ w4 * inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n"
"vmla.f32 q12, q7, d1[0] @ w5 * inr22\n"
"vmla.f32 q13, q7, d1[1] @ w5 * inr23\n"
"vmla.f32 q14, q7, d2[0] @ w5 * inr24\n"
"vmla.f32 q15, q7, d2[1] @ w5 * inr25\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n"
"sub %[wc0], %[wc0], #144 @ wc0 - 144\n"
/* mul r2 with w6, w7, w8, get out r0 */
"vmla.f32 q8, q5, d0[0] @ w6 * "
"inr20\n"
"vmla.f32 q9, q5, d0[1] @ w6 * "
"inr21\n"
"vld1.32 {d3-d4}, [%[r3]]! @ load r3, "
"4 float\n"
"vmla.f32 q10, q5, d1[0] @ w6 * "
"inr22\n"
"vmla.f32 q11, q5, d1[1] @ w6 * "
"inr23\n"
"vmla.f32 q8, q6, d0[1] @ w7 * "
"inr21\n"
"vmla.f32 q9, q6, d1[0] @ w7 * "
"inr22\n"
"vld1.32 {d5}, [%[r3]] @ load r3, "
"2 float\n"
"vmla.f32 q10, q6, d1[1] @ w7 * "
"inr23\n"
"vmla.f32 q11, q6, d2[0] @ w7 * "
"inr24\n"
"vmla.f32 q8, q7, d1[0] @ w8 * "
"inr22\n"
"vmla.f32 q9, q7, d1[1] @ w8 * "
"inr23\n"
"vld1.32 {d0-d1}, [%[r0]]! @ load r0, "
"4 float\n"
"vmla.f32 q10, q7, d2[0] @ w8 * "
"inr24\n"
"vmla.f32 q11, q7, d2[1] @ w8 * "
"inr25\n"
"vld1.32 {d2}, [%[r0]] @ load r0, "
"2 float\n"
"vmla.f32 q8, q5, d0[0] @ w6 * inr20\n"
"vmla.f32 q9, q5, d0[1] @ w6 * inr21\n"
"vld1.32 {d3-d4}, [%[r3]]! @ load r3\n"
"vmla.f32 q10, q5, d1[0] @ w6 * inr22\n"
"vmla.f32 q11, q5, d1[1] @ w6 * inr23\n"
"vmla.f32 q8, q6, d0[1] @ w7 * inr21\n"
"vmla.f32 q9, q6, d1[0] @ w7 * inr22\n"
"vld1.32 {d5}, [%[r3]] @ load r3\n"
"vmla.f32 q10, q6, d1[1] @ w7 * inr23\n"
"vmla.f32 q11, q6, d2[0] @ w7 * inr24\n"
"vmla.f32 q8, q7, d1[0] @ w8 * inr22\n"
"vmla.f32 q9, q7, d1[1] @ w8 * inr23\n"
"vld1.32 {d0-d1}, [%[r0]]! @ load r0\n"
"vmla.f32 q10, q7, d2[0] @ w8 * inr24\n"
"vmla.f32 q11, q7, d2[1] @ w8 * inr25\n"
"vld1.32 {d2}, [%[r0]] @ load r0\n"
/* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q12, q5, d3[0] @ w6 * "
"inr20\n"
"vmla.f32 q13, q5, d3[1] @ w6 * "
"inr21\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save "
"r00, r01, c0~c3\n"
"vmla.f32 q14, q5, d4[0] @ w6 * "
"inr22\n"
"vmla.f32 q15, q5, d4[1] @ w6 * "
"inr23\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save "
"r02, r03, c0~c3\n"
"vmla.f32 q12, q6, d3[1] @ w7 * "
"inr21\n"
"vmla.f32 q13, q6, d4[0] @ w7 * "
"inr22\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vmla.f32 q14, q6, d4[1] @ w7 * "
"inr23\n"
"vmla.f32 q15, q6, d5[0] @ w7 * "
"inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vmla.f32 q12, q7, d4[0] @ w8 * "
"inr22\n"
"vmla.f32 q13, q7, d4[1] @ w8 * "
"inr23\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"vmla.f32 q14, q7, d5[0] @ w8 * "
"inr24\n"
"vmla.f32 q15, q7, d5[1] @ w8 * "
"inr25\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
"r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
"r12, r13, c0~c3\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"subs %[cnt], #1 @ loop "
"count--\n"
"bne 0b @ jump to "
"main loop\n"
"vmla.f32 q12, q5, d3[0] @ w6 * inr20\n"
"vmla.f32 q13, q5, d3[1] @ w6 * inr21\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n"
"vmla.f32 q14, q5, d4[0] @ w6 * inr22\n"
"vmla.f32 q15, q5, d4[1] @ w6 * inr23\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n"
"vmla.f32 q12, q6, d3[1] @ w7 * inr21\n"
"vmla.f32 q13, q6, d4[0] @ w7 * inr22\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"vmla.f32 q14, q6, d4[1] @ w7 * inr23\n"
"vmla.f32 q15, q6, d5[0] @ w7 * inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vmla.f32 q12, q7, d4[0] @ w8 * inr22\n"
"vmla.f32 q13, q7, d4[1] @ w8 * inr23\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
"vmla.f32 q14, q7, d5[0] @ w8 * inr24\n"
"vmla.f32 q15, q7, d5[1] @ w8 * inr25\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
"subs %[cnt], #1 @ loop count--\n"
"bne 0b @ jump to main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1),
[wc0] "+r"(wc0)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
: "cc","memory","q0","q1","q2","q3",
"q4","q5","q6","q7","q8","q9",
"q10","q11","q12","q13","q14","q15");
// clang-format on
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
......@@ -602,7 +458,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
write_to_output_c4_fp32(pre_out,
dout_batch,
c,
c + hout_c_block,
c + OUT_C_BLOCK,
h,
h + h_kernel,
0,
......@@ -641,7 +497,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
}
fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_remain_ptr;
const float* inr0 = block_inr0;
......@@ -672,109 +528,66 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3;
int cnt = w_loop;
// clang-format off
asm volatile(
"ldr q21, [%[ptr_out0]] \n" /* load outr0,
w0~w3*/
"ldr q21, [%[ptr_out0]]\n" /* load outr0, w0~w3*/
"ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"2: \n" /* main loop*/
"fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/
"fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/
"ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/
"ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/
"ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/
"ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/
"fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/
"fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/
"fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/
"fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/
"ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/
"ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/
"ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/
"ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/
"fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/
"fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/
"fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/
"fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/
"fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/
"fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/
"fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/
"str q21, [%[ptr_out0]], #16 \n" /*write output r0*/
"str q22, [%[ptr_out1]], #16 \n" /*write output r1*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/
"ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v21",
"v22");
wc0 += 9 * hout_c_block;
: [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
[w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5),
[w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8)
: "cc","memory","v0",
"v1","v2","v3","v4","v5","v6",
"v7","v8","v9","v10","v11","v12",
"v13","v14","v15","v21","v22"
);
// clang-format on
wc0 += 9 * OUT_C_BLOCK;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
......@@ -806,181 +619,96 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3;
int cnt = w_loop / 2;
if (cnt > 0) {
// clang-format off
asm volatile(
"vld1.32 {d24-d27}, [%[ptr_out0]] @ "
"load or00, or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r0, 8 "
"float\n"
"vld1.32 {d10}, [%[r0]] @ load r0, 2 "
"float\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r0\n"
"vld1.32 {d10}, [%[r0]] @ load r0\n"
/* main loop */
"0: @ main loop\n"
/* r0 * w0, w1, w2, get out r0*/
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, "
"or11\n"
"vext.32 q8, q3, q4, #1 @ r0, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r0, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, "
"0, 1, 2, 3\n"
"vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, "
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r0, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r0, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r1]] @ load r1, 2 "
"float\n"
"vld1.32 {d28-d31}, [%[ptr_out1]]@ load or10 or11\n"
"vext.32 q8, q3, q4, #1 @ r0, shift left 1\n"
"vext.32 q9, q4, q5, #1 @ r0, shift left 1\n"
"vmla.f32 q12, q3, %e[w0][0] @ w00 * r0\n"
"vmla.f32 q13, q4, %e[w0][0] @ w00 * r0\n"
"vext.32 q10, q3, q4, #2 @ r0, shift left 2\n"
"vext.32 q11, q4, q5, #2 @ r0, shift left 2\n"
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n"
"vmla.f32 q13, q9, %e[w0][1] @ w01 * r0\n"
"vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8\n"
"vmla.f32 q12, q10, %f[w0][0] @ w02 * r0\n"
"vmla.f32 q13, q11, %f[w0][0] @ w02 * r0\n"
"vld1.32 {d10}, [%[r1]] @ load r1\n"
/* r1 * w3, w4, w5, get out r0*/
/* r1 * w0, w1, w2, get out r1*/
"vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, "
"0, 1, 2, 3\n"
"vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, "
"4, 5, 6, 7\n"
"vext.32 q8, q3, q4, #1 @ r1, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r1, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, "
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r1, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r1, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, "
"5, 6, 7, 8\n"
"vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, "
"1, 2, 3, 4\n"
"vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, "
"6, 7, 8, 9\n"
"vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r2]] @ load r2, 2 "
"float\n"
"vmla.f32 q12, q3, %e[w1][0] @ w10 * r1\n"
"vmla.f32 q13, q4, %e[w1][0] @ w10 * r1\n"
"vext.32 q8, q3, q4, #1 @ r1, shift left 1\n"
"vext.32 q9, q4, q5, #1 @ r1, shift left 1\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r1\n"
"vmla.f32 q15, q4, %e[w0][0] @ w00 * r1\n"
"vext.32 q10, q3, q4, #2 @ r1, shift left 2\n"
"vext.32 q11, q4, q5, #2 @ r1, shift left 2\n"
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n"
"vmla.f32 q13, q9, %e[w1][1] @ w11 * r1\n"
"vmla.f32 q14, q8, %e[w0][1] @ w01 * r1\n"
"vmla.f32 q15, q9, %e[w0][1] @ w01 * r1\n"
"vld1.32 {d6-d9}, [%[r2]]! @ load r2\n"
"vmla.f32 q12, q10, %f[w1][0] @ w12 * r1\n"
"vmla.f32 q13, q11, %f[w1][0] @ w12 * r1\n"
"vmla.f32 q14, q10, %f[w0][0] @ w02 * r1\n"
"vmla.f32 q15, q11, %f[w0][0] @ w02 * r1\n"
"vld1.32 {d10}, [%[r2]] @ load r2\n"
/* r2 * w6, w7, w8, get out r0*/
/* r2 * w3, w4, w5, get out r1*/
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, "
"0, 1, 2, 3\n"
"vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, "
"4, 5, 6, 7\n"
"vext.32 q8, q3, q4, #1 @ r2, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r2, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, "
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r2, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r2, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, "
"5, 6, 7, 8\n"
"vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, "
"1, 2, 3, 4\n"
"vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r3]]! @ load r3, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, "
"6, 7, 8, 9\n"
"vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r3]] @ load r3, 2 "
"float\n"
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n"
"vmla.f32 q13, q4, %e[w2][0] @ w20 * r2\n"
"vext.32 q8, q3, q4, #1 @ r2, shift left 1\n"
"vext.32 q9, q4, q5, #1 @ r2, shift left 1\n"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r2\n"
"vmla.f32 q15, q4, %e[w1][0] @ w10 * r2\n"
"vext.32 q10, q3, q4, #2 @ r2, shift left 2\n"
"vext.32 q11, q4, q5, #2 @ r2, shift left 2\n"
"vmla.f32 q12, q8, %e[w2][1] @ w21 * r2\n"
"vmla.f32 q13, q9, %e[w2][1] @ w21 * r2\n"
"vmla.f32 q14, q8, %e[w1][1] @ w11 * r2\n"
"vmla.f32 q15, q9, %e[w1][1] @ w11 * r2\n"
"vld1.32 {d6-d9}, [%[r3]]! @ load r3\n"
"vmla.f32 q12, q10, %f[w2][0] @ w22 * r2\n"
"vmla.f32 q13, q11, %f[w2][0] @ w22 * r2\n"
"vmla.f32 q14, q10, %f[w1][0] @ w12 * r2\n"
"vmla.f32 q15, q11, %f[w1][0] @ w12 * r2\n"
"vld1.32 {d10}, [%[r3]] @ load r3\n"
/* r3 * w6, w7, w8, get out r1*/
"vext.32 q8, q3, q4, #1 @ r3, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r3, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, "
"4, 5, 6, 7\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, "
"or01\n"
"vext.32 q10, q3, q4, #2 @ r3, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r3, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, "
"4, 5, 6, 7\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
"or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r3, 8 "
"float\n"
"vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r0]] @ load r0, 2 "
"float\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, "
"or11\n"
"subs %[cnt], #1 @loop count "
"-1\n"
"bne 0b @ jump to "
"main loop\n"
"vext.32 q8, q3, q4, #1 @ r3, shift left 1\n"
"vext.32 q9, q4, q5, #1 @ r3, shift left 1\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r3\n"
"vmla.f32 q15, q4, %e[w2][0] @ w20 * r3\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, or01\n"
"vext.32 q10, q3, q4, #2 @ r3, shift left 2\n"
"vext.32 q11, q4, q5, #2 @ r3, shift left 2\n"
"vmla.f32 q14, q8, %e[w2][1] @ w21 * r3\n"
"vmla.f32 q15, q9, %e[w2][1] @ w21 * r3\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00,or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r3\n"
"vmla.f32 q14, q10, %f[w2][0] @ w22 * r3\n"
"vmla.f32 q15, q11, %f[w2][0] @ w22 * r3\n"
"vld1.32 {d10}, [%[r0]] @ load r0\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, or11\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ jump to main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
: "cc","memory","q3","q4",
"q5","q6","q7","q8","q9","q10",
"q11","q12","q13","q14","q15"
);
// clang-format on
r0 -= 8;
}
//! deal with remain ow
......
......@@ -24,6 +24,39 @@ namespace lite {
namespace arm {
namespace math {
//! Blocking factors used by the 3x3 stride-2 direct convolution in this file.
//! Number of output channels processed together in one packed block.
constexpr int OUT_C_BLOCK = 4;
//! Number of output rows advanced per kernel step.
constexpr int OUT_H_BLOCK = 2;
//! Output width is rounded up to a multiple of this value.
constexpr int OUT_W_BLOCK = 4;
//! Returns the number of bytes of workspace that conv_3x3s2_direct_fp32 needs.
//!
//! The blocking below must stay identical to the one computed inside
//! conv_3x3s2_direct_fp32. That routine places one pre-packed input tile of
//! pre_in_size floats at the start of the workspace, followed by one output
//! tile of pre_out_size floats per thread.
//!
//! \param param convolution parameters; only the dims of x and output are read
//! \param ctx   ARM context providing the thread count and llc_size()
//! \return required workspace size in bytes
size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
                                       ARMContext* ctx) {
  auto dim_in = param.x->dims();
  auto dim_out = param.output->dims();
  const int threads = ctx->threads();
  // Cache budget expressed in floats rather than bytes.
  const int llc_size = ctx->llc_size() / sizeof(float);
  const int ow = dim_out[3];
  const int oh = dim_out[2];
  const int ic = dim_in[1];

  const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
  const int win_round = wout_round * 2 /*stride_w*/ + 1;

  //! get h block
  //! win_round * ic * hin_r_block + wout_round * OUT_C_BLOCK * hout_r_block
  //! * threads = llc_size
  //! win_round = 2 * wout_round + 1
  //! hin_r_block = 2 * hout_r_block + 1
  int hout_r_block =
      (llc_size - 2 * wout_round * ic - ic) /
      ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads);
  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
  hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
  hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;

  // The input tile height has to follow the final hout_r_block, exactly as
  // conv_3x3s2_direct_fp32 does. Using the minimum block (OUT_H_BLOCK) here
  // under-reports the workspace whenever hout_r_block > OUT_H_BLOCK.
  const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1;

  const int in_len = win_round * ic;
  const int pre_in_size = hin_r_block * in_len;
  const int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;

  return sizeof(float) * (pre_in_size + threads * pre_out_size);
}
void conv_3x3s2_direct_fp32(const float* i_data,
float* o_data,
int bs,
......@@ -44,53 +77,50 @@ void conv_3x3s2_direct_fp32(const float* i_data,
int l2_size = ctx->llc_size() / sizeof(float);
const int pad_w = param.paddings[1];
const int pad_h = param.paddings[0];
const int hout_c_block = 4;
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block;
const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
const int win_round = wout_round * 2 /*stride_w*/ + 1;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
//! get h block
//! win_round * ic * hin_r_block + wout_round * hout_c_block * hout_r_block
//! win_round * ic * hin_r_block + wout_round * OUT_C_BLOCK * hout_r_block
//! * threads = l2_size
//! win_round = 2 * wout_round + 1
//! hin_r_block = 2 * hout_r_block + 1
int hout_r_block =
(l2_size - 2 * wout_round * ic - ic) /
((4 * wout_round + 2) * ic + wout_round * hout_c_block * threads);
((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads);
hout_r_block = hout_r_block > oh ? oh : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block;
hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1;
int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
float* tmp_work_space = ctx->workspace_data<float>();
float ptr_zero[win_round]; // NOLINT
memset(ptr_zero, 0, sizeof(float) * win_round);
float ptr_write[wout_round]; // NOLINT
int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = hout_c_block * hout_r_block * wout_round;
//! l2_cache start
float* pre_din = tmp_work_space;
int size_in_channel = win * ih;
int size_out_channel = ow * oh;
int w_stride = ic * 9; /*kernel_w * kernel_h*/
int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h *
int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h *
int ws = -pad_w;
int we = ws + win_round;
int w_loop = wout_round / 4;
int c_remain = oc - (oc / hout_c_block) * hout_c_block;
int c_round_down = (oc / hout_c_block) * hout_c_block;
int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int out_row_stride = hout_c_block * wout_round;
int out_row_stride = OUT_C_BLOCK * wout_round;
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
......@@ -114,7 +144,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* cblock_inr4 = cblock_inr3 + in_len;
#pragma omp parallel for num_threads(threads)
for (int c = 0; c < c_round_down; c += hout_c_block) {
for (int c = 0; c < c_round_down; c += OUT_C_BLOCK) {
#ifdef ARM_WITH_OMP
float* pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
......@@ -133,9 +163,9 @@ void conv_3x3s2_direct_fp32(const float* i_data,
bias_ptr = bias + c;
}
fill_packed_biasc4(
pre_out, bias_ptr, wout_round * hout_c_block * h_kernel);
pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_c;
const float* inr0 = block_inr0;
......@@ -168,18 +198,15 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* r4 = inr4;
int cnt = w_loop;
// clang-format off
asm volatile(
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00,
outr01*/
"ldp q15, q16, [%[ptr_out0]]\n" /* load outr00, outr01*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"2: \n" /* main loop*/
"ldp q0, q1, [%[r0]], #32\n" /* load input r0*/
"ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
"ldp q4, q5, [%[r2]], #32\n" /* load input r2*/
"ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"2:\n" /* main loop*/
/* r0, r2, mul w0, get out r0, r1 */
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
......@@ -191,18 +218,13 @@ void conv_3x3s2_direct_fp32(const float* i_data,
"fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/
"fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/
"fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/
"ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/
/* r2 mul w6, get out r0*/
"fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/
"fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/
"fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/
"ldr d11, [%[r1]] \n" /* load input r1, 9th
element*/
"ldr d11, [%[r1]]\n" /* load input r1, 9th element*/
/* r0, r2, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/
"fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/
......@@ -212,42 +234,29 @@ void conv_3x3s2_direct_fp32(const float* i_data,
"fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/
"fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/
"fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/
"ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/
/* r2 mul w7, get out r0 */
"fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/
"fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/
"fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/
"ldr d13, [%[r3]] \n" /* load input r3, 9th
element*/
"ldr d13, [%[r3]]\n" /* load input r3, 9th element*/
/* r0, r2, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/
"fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/
"fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 *
r0[8]*/
"fmla v18.4s , %[w2].4s, v10.s[0]\n"/* outr03 = w2 * r0[8]*/
"fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/
"fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/
"fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/
"fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 *
r2[8]*/
"fmla v22.4s , %[w2].4s, v12.s[0]\n"/* outr13 = w2 * r2[8]*/
"ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/
/* r2, mul w8, get out r0 */
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/
"fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/
"fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/
"fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 *
r2[8]*/
"ldr d14, [%[r4]] \n" /* load input r4, 9th
element*/
"fmla v18.4s , %[w8].4s, v12.s[0]\n"/* outr03 = w8 * r2[8]*/
"ldr d14, [%[r4]]\n" /* load input r4, 9th element*/
/* r1, r3, mul w3, get out r0, r1 */
"fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/
"fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/
......@@ -257,9 +266,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
"fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/
"fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/
"fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/
/* r1, r3, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/
"fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/
......@@ -269,104 +276,55 @@ void conv_3x3s2_direct_fp32(const float* i_data,
"fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/
"fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/
"fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
/* r1, r3, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/
"fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/
"fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/
"fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 *
r1[8]*/
"fmla v18.4s , %[w5].4s, v11.s[0]\n"/* outr03 = w5 * r1[8]*/
"ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/
"stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/
"fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/
"fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/
"fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/
"fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 *
r3[8]*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"fmla v22.4s , %[w5].4s, v13.s[0]\n"/* outr13 = w5 * r3[8]*/
"ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/
/* r4, mul w6, get out r1 */
"fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/
"fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/
"fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/
"fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
/* r4, mul w7, get out r1 */
"fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/
"fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/
"fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/
"fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
/* r4, mul w8, get out r1 */
"fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/
"fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/
"fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/
"fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 *
r4[8]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"fmla v22.4s , %[w8].4s, v14.s[0]\n"/* outr13 = w8 * r4[8]*/
"subs %w[cnt], %w[cnt], #1\n" /*loop count -1*/
"stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/
"stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4),
: [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3), [r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
wc0 += 9 * hout_c_block;
[w1] "w"(w1), [w2] "w"(w2),
[w3] "w"(w3), [w4] "w"(w4),
[w5] "w"(w5), [w6] "w"(w6),
[w7] "w"(w7), [w8] "w"(w8)
: "cc","memory","v0","v1","v2","v3","v4",
"v5","v6","v7","v8","v9","v10","v11","v12","v13",
"v14","v15","v16","v17","v18","v19","v20","v21","v22");
// clang-format on
wc0 += 9 * OUT_C_BLOCK;
inr0 += win_round;
inr1 += win_round;
inr2 += win_round;
......@@ -387,285 +345,142 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* r4 = inr4;
int cnt = w_loop;
// clang-format off
asm volatile(
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ "
"load outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
/* load weights */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
/* load r0, r2 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, "
"8 float\n"
"vld1.32 {d8}, [%[r0]] @ load r0, "
"9th float\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"vld1.32 {d0-d3}, [%[r0]]! @ load r0\n"
"vld1.32 {d8}, [%[r0]] @ load r0\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 -32\n"
/* main loop */
"0: @ main "
"loop\n"
"0: @ main loop\n"
/* mul r0, with w0, w1, w2 */
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load "
"outr1, w0, w1, c0~c3\n"
"vmla.f32 q8, q5, d0[0] @ w0 * "
"inr00\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load "
"outr1, w2, w3, c0~c3\n"
"vmla.f32 q9, q5, d1[0] @ w0 * "
"inr02\n"
"vmla.f32 q10, q5, d2[0] @ w0 * "
"inr04\n"
"vmla.f32 q11, q5, d3[0] @ w0 * "
"inr06\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w1 * "
"inr01\n"
"vmla.f32 q9, q6, d1[1] @ w1 * "
"inr03\n"
"vmla.f32 q10, q6, d2[1] @ w1 * "
"inr05\n"
"vmla.f32 q11, q6, d3[1] @ w1 * "
"inr07\n"
"vld1.32 {d9}, [%[r2]] @ load r2, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w2 * "
"inr02\n"
"vmla.f32 q9, q7, d2[0] @ w2 * "
"inr04\n"
"vmla.f32 q10, q7, d3[0] @ w2 * "
"inr06\n"
"vmla.f32 q11, q7, d8[0] @ w2 * "
"inr08\n"
"sub %[r2], %[r2], #32 @ r2 - 32, "
"load r2 twice\n"
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n"
"vmla.f32 q8, q5, d0[0] @ w0 * inr00\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n"
"vmla.f32 q9, q5, d1[0] @ w0 * inr02\n"
"vmla.f32 q10, q5, d2[0] @ w0 * inr04\n"
"vmla.f32 q11, q5, d3[0] @ w0 * inr06\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2\n"
"vmla.f32 q8, q6, d0[1] @ w1 * inr01\n"
"vmla.f32 q9, q6, d1[1] @ w1 * inr03\n"
"vmla.f32 q10, q6, d2[1] @ w1 * inr05\n"
"vmla.f32 q11, q6, d3[1] @ w1 * inr07\n"
"vld1.32 {d9}, [%[r2]] @ load r2, 9th float\n"
"vmla.f32 q8, q7, d1[0] @ w2 * inr02\n"
"vmla.f32 q9, q7, d2[0] @ w2 * inr04\n"
"vmla.f32 q10, q7, d3[0] @ w2 * inr06\n"
"vmla.f32 q11, q7, d8[0] @ w2 * inr08\n"
"sub %[r2], %[r2], #32 @ r2 - 32\n"
/* mul r2, with w0, w1, w2 */
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, "
"8 float\n"
"vmla.f32 q12, q5, d4[0] @ w0 * "
"inr20\n"
"vmla.f32 q13, q5, d5[0] @ w0 * "
"inr22\n"
"vmla.f32 q14, q5, d6[0] @ w0 * "
"inr24\n"
"vmla.f32 q15, q5, d7[0] @ w0 * "
"inr26\n"
"vld1.32 {d8}, [%[r1]] @ load r1, "
"9th float\n"
"vmla.f32 q12, q6, d4[1] @ w1 * "
"inr21\n"
"vmla.f32 q13, q6, d5[1] @ w1 * "
"inr23\n"
"vmla.f32 q14, q6, d6[1] @ w1 * "
"inr25\n"
"vmla.f32 q15, q6, d7[1] @ w1 * "
"inr27\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, "
"w4, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w2 * "
"inr22\n"
"vmla.f32 q13, q7, d6[0] @ w2 * "
"inr24\n"
"vmla.f32 q14, q7, d7[0] @ w2 * "
"inr26\n"
"vmla.f32 q15, q7, d9[0] @ w2 * "
"inr28\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, "
"to q7\n"
"vld1.32 {d0-d3}, [%[r1]]! @ load r1\n"
"vmla.f32 q12, q5, d4[0] @ w0 * inr20\n"
"vmla.f32 q13, q5, d5[0] @ w0 * inr22\n"
"vmla.f32 q14, q5, d6[0] @ w0 * inr24\n"
"vmla.f32 q15, q5, d7[0] @ w0 * inr26\n"
"vld1.32 {d8}, [%[r1]] @ load r1, 9th float\n"
"vmla.f32 q12, q6, d4[1] @ w1 * inr21\n"
"vmla.f32 q13, q6, d5[1] @ w1 * inr23\n"
"vmla.f32 q14, q6, d6[1] @ w1 * inr25\n"
"vmla.f32 q15, q6, d7[1] @ w1 * inr27\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w2 * inr22\n"
"vmla.f32 q13, q7, d6[0] @ w2 * inr24\n"
"vmla.f32 q14, q7, d7[0] @ w2 * inr26\n"
"vmla.f32 q15, q7, d9[0] @ w2 * inr28\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, to q7\n"
/* mul r1, with w3, w4, w5 */
"vmla.f32 q8, q5, d0[0] @ w3 * "
"inr10\n"
"vmla.f32 q9, q5, d1[0] @ w3 * "
"inr12\n"
"vmla.f32 q10, q5, d2[0] @ w3 * "
"inr14\n"
"vmla.f32 q11, q5, d3[0] @ w3 * "
"inr16\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w4 * "
"inr11\n"
"vmla.f32 q9, q6, d1[1] @ w4 * "
"inr13\n"
"vmla.f32 q10, q6, d2[1] @ w4 * "
"inr15\n"
"vmla.f32 q11, q6, d3[1] @ w4 * "
"inr17\n"
"vld1.32 {d9}, [%[r3]] @ load r3, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w5 * "
"inr12\n"
"vmla.f32 q9, q7, d2[0] @ w5 * "
"inr14\n"
"vmla.f32 q10, q7, d3[0] @ w5 * "
"inr16\n"
"vmla.f32 q11, q7, d8[0] @ w5 * "
"inr18\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 "
"- 32, to start address\n"
"vmla.f32 q8, q5, d0[0] @ w3 * inr10\n"
"vmla.f32 q9, q5, d1[0] @ w3 * inr12\n"
"vmla.f32 q10, q5, d2[0] @ w3 * inr14\n"
"vmla.f32 q11, q5, d3[0] @ w3 * inr16\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, 8 float\n"
"vmla.f32 q8, q6, d0[1] @ w4 * inr11\n"
"vmla.f32 q9, q6, d1[1] @ w4 * inr13\n"
"vmla.f32 q10, q6, d2[1] @ w4 * inr15\n"
"vmla.f32 q11, q6, d3[1] @ w4 * inr17\n"
"vld1.32 {d9}, [%[r3]] @ load r3, 9th float\n"
"vmla.f32 q8, q7, d1[0] @ w5 * inr12\n"
"vmla.f32 q9, q7, d2[0] @ w5 * inr14\n"
"vmla.f32 q10, q7, d3[0] @ w5 * inr16\n"
"vmla.f32 q11, q7, d8[0] @ w5 * inr18\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n"
/* mul r3, with w3, w4, w5 */
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, "
"8 float\n"
"vmla.f32 q12, q5, d4[0] @ w3 * "
"inr30\n"
"vmla.f32 q13, q5, d5[0] @ w3 * "
"inr32\n"
"vmla.f32 q14, q5, d6[0] @ w3 * "
"inr34\n"
"vmla.f32 q15, q5, d7[0] @ w3 * "
"inr36\n"
"vld1.32 {d8}, [%[r2]] @ load r2, "
"9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * "
"inr31\n"
"vmla.f32 q13, q6, d5[1] @ w4 * "
"inr33\n"
"vmla.f32 q14, q6, d6[1] @ w4 * "
"inr35\n"
"vmla.f32 q15, q6, d7[1] @ w4 * "
"inr37\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
"w7, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w5 * "
"inr32\n"
"vmla.f32 q13, q7, d6[0] @ w5 * "
"inr34\n"
"vmla.f32 q14, q7, d7[0] @ w5 * "
"inr36\n"
"vmla.f32 q15, q7, d9[0] @ w5 * "
"inr38\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
"to q7\n"
"vld1.32 {d0-d3}, [%[r2]]! @ load r2\n"
"vmla.f32 q12, q5, d4[0] @ w3 * inr30\n"
"vmla.f32 q13, q5, d5[0] @ w3 * inr32\n"
"vmla.f32 q14, q5, d6[0] @ w3 * inr34\n"
"vmla.f32 q15, q5, d7[0] @ w3 * inr36\n"
"vld1.32 {d8}, [%[r2]] @ load r2, 9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * inr31\n"
"vmla.f32 q13, q6, d5[1] @ w4 * inr33\n"
"vmla.f32 q14, q6, d6[1] @ w4 * inr35\n"
"vmla.f32 q15, q6, d7[1] @ w4 * inr37\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n"
"vmla.f32 q12, q7, d5[0] @ w5 * inr32\n"
"vmla.f32 q13, q7, d6[0] @ w5 * inr34\n"
"vmla.f32 q14, q7, d7[0] @ w5 * inr36\n"
"vmla.f32 q15, q7, d9[0] @ w5 * inr38\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n"
/* mul r2, with w6, w7, w8 */
"vmla.f32 q8, q5, d0[0] @ w6 * "
"inr20\n"
"vmla.f32 q9, q5, d1[0] @ w6 * "
"inr22\n"
"vmla.f32 q10, q5, d2[0] @ w6 * "
"inr24\n"
"vmla.f32 q11, q5, d3[0] @ w6 * "
"inr26\n"
"vld1.32 {d4-d7}, [%[r4]]! @ load r4, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w7 * "
"inr21\n"
"vmla.f32 q9, q6, d1[1] @ w7 * "
"inr23\n"
"vmla.f32 q10, q6, d2[1] @ w7 * "
"inr25\n"
"vmla.f32 q11, q6, d3[1] @ w7 * "
"inr27\n"
"vld1.32 {d9}, [%[r4]] @ load r4, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w8 * "
"inr22\n"
"vmla.f32 q9, q7, d2[0] @ w8 * "
"inr24\n"
"vmla.f32 q10, q7, d3[0] @ w8 * "
"inr26\n"
"vmla.f32 q11, q7, d8[0] @ w8 * "
"inr28\n"
"sub %[wc0], %[wc0], #144 @ wc0 - "
"144 to start address\n"
"vmla.f32 q8, q5, d0[0] @ w6 * inr20\n"
"vmla.f32 q9, q5, d1[0] @ w6 * inr22\n"
"vmla.f32 q10, q5, d2[0] @ w6 * inr24\n"
"vmla.f32 q11, q5, d3[0] @ w6 * inr26\n"
"vld1.32 {d4-d7}, [%[r4]]! @ load r4\n"
"vmla.f32 q8, q6, d0[1] @ w7 * inr21\n"
"vmla.f32 q9, q6, d1[1] @ w7 * inr23\n"
"vmla.f32 q10, q6, d2[1] @ w7 * inr25\n"
"vmla.f32 q11, q6, d3[1] @ w7 * inr27\n"
"vld1.32 {d9}, [%[r4]] @ load r4, 9th float\n"
"vmla.f32 q8, q7, d1[0] @ w8 * inr22\n"
"vmla.f32 q9, q7, d2[0] @ w8 * inr24\n"
"vmla.f32 q10, q7, d3[0] @ w8 * inr26\n"
"vmla.f32 q11, q7, d8[0] @ w8 * inr28\n"
"sub %[wc0], %[wc0], #144 @ wc0 - 144\n"
/* mul r4, with w6, w7, w8 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, "
"8 float\n"
"vmla.f32 q12, q5, d4[0] @ w3 * "
"inr40\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save "
"r00, r01, c0~c3\n"
"vmla.f32 q13, q5, d5[0] @ w3 * "
"inr42\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save "
"r02, r03, c0~c3\n"
"vmla.f32 q14, q5, d6[0] @ w3 * "
"inr44\n"
"vmla.f32 q15, q5, d7[0] @ w3 * "
"inr46\n"
"vld1.32 {d8}, [%[r0]] @ load "
"r0, 9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * "
"inr41\n"
"vmla.f32 q13, q6, d5[1] @ w4 * "
"inr43\n"
"vmla.f32 q14, q6, d6[1] @ w4 * "
"inr45\n"
"vmla.f32 q15, q6, d7[1] @ w4 * "
"inr47\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w5 * "
"inr42\n"
"vmla.f32 q13, q7, d6[0] @ w5 * "
"inr44\n"
"vmla.f32 q14, q7, d7[0] @ w5 * "
"inr46\n"
"vmla.f32 q15, q7, d9[0] @ w5 * "
"inr48\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
"r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
"r12, r13, c0~c3\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"subs %[cnt], #1 @ loop "
"count--\n"
"bne 0b @ jump to "
"main loop\n"
"vld1.32 {d0-d3}, [%[r0]]! @ load r0\n"
"vmla.f32 q12, q5, d4[0] @ w3 * inr40\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n"
"vmla.f32 q13, q5, d5[0] @ w3 * inr42\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n"
"vmla.f32 q14, q5, d6[0] @ w3 * inr44\n"
"vmla.f32 q15, q5, d7[0] @ w3 * inr46\n"
"vld1.32 {d8}, [%[r0]] @ load r0, 9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * inr41\n"
"vmla.f32 q13, q6, d5[1] @ w4 * inr43\n"
"vmla.f32 q14, q6, d6[1] @ w4 * inr45\n"
"vmla.f32 q15, q6, d7[1] @ w4 * inr47\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vmla.f32 q12, q7, d5[0] @ w5 * inr42\n"
"vmla.f32 q13, q7, d6[0] @ w5 * inr44\n"
"vmla.f32 q14, q7, d7[0] @ w5 * inr46\n"
"vmla.f32 q15, q7, d9[0] @ w5 * inr48\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
"subs %[cnt], #1 @ loop count--\n"
"bne 0b @ jump to main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1),
[wc0] "+r"(wc0)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
: "cc","memory","q0","q1","q2","q3","q4",
"q5","q6","q7","q8","q9","q10",
"q11","q12","q13","q14","q15"
);
// clang-format on
inr0 += win_round;
inr1 += win_round;
......@@ -684,7 +499,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
write_to_output_c4_fp32(pre_out,
dout_batch,
c,
c + hout_c_block,
c + OUT_C_BLOCK,
h,
h + h_kernel,
0,
......@@ -721,7 +536,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
}
fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) {
for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_c;
const float* inr0 = block_inr0;
......@@ -755,158 +570,80 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* r4 = inr4;
int cnt = w_loop;
// clang-format off
asm volatile(
"ldr q21, [%[ptr_out0]] \n" /* load outr00,
outr01,
outr02,
outr03*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"2: \n" /* main loop*/
"ldr q21, [%[ptr_out0]]\n" /* load outr00-outr03*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/
"ldr d10, [%[r0]]\n"/* load input r0, 9th element*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32\n" /* load input r2*/
"ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"2:\n" /* main loop*/
/* r0, r2, mul w0, get out r0, r1 */
"ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11,
outr12, outr13*/
"fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2,
4, 6]*/
"fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2,
4, 6]*/
"ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/
"ldr q22, [%[ptr_out1]]\n" /* load outr10 - outr13*/
"fmla v21.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0*/
"fmla v22.4s , %[w0].4s, v4.4s\n" /* outr1 = w0 * r2*/
"ld2 {v2.4s, v3.4s}, [%[r1]], #32\n" /* load input r1*/
/* r2 mul w6, get out r0*/
"fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2,
4, 6]*/
"ldr d11, [%[r1]] \n" /* load input r1, 9th
element*/
"fmla v21.4s , %[w6].4s, v4.4s\n" /* outr0 = w6 * r2*/
"ldr d11, [%[r1]]\n" /* load input r1, 9th element*/
/* shift left 1 */
"ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/
"ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/
/* r0, r2, mul w1, get out r0, r1 */
"fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3,
5, 7]*/
"fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3,
5, 7]*/
"ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/
"fmla v21.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0*/
"fmla v22.4s , %[w1].4s, v5.4s\n" /* outr1 = w1 * r2*/
"ld2 {v6.4s, v7.4s}, [%[r3]], #32\n" /* load input r3*/
/* r2 mul w7, get out r0 */
"fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1,
3, 5, 7]*/
"ldr d13, [%[r3]] \n" /* load input r3, 9th
element*/
"fmla v21.4s , %[w7].4s, v5.4s\n" /* outr00 = w7 * r2*/
"ldr d13, [%[r3]]\n" /* load input r3, 9th element*/
/* r0, r2, mul w2, get out r0, r1 */
"fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * r0[2, 4,
6, 8]*/
"fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4,
6, 8]*/
"fmla v21.4s , %[w2].4s, v15.4s\n" /* outr0 = w2 * r0*/
"fmla v22.4s , %[w2].4s, v16.4s\n" /* outr1 = w2 * r2*/
"ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/
/* r2, mul w8, get out r0 */
"fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2,
4, 6, 8]*/
"ldr d14, [%[r4]] \n" /* load input r4, 9th
element*/
"fmla v21.4s , %[w8].4s, v16.4s\n" /* outr00 = w8 * r2*/
"ldr d14, [%[r4]]\n" /* load input r4, 9th element*/
/* r1, r3, mul w3, get out r0, r1 */
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2,
4, 6]*/
"fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2,
4, 6]*/
"fmla v21.4s , %[w3].4s, v2.4s\n" /* outr0 = w3 * r1*/
"fmla v22.4s , %[w3].4s, v6.4s\n" /* outr1 = w3 * r3*/
/* shift left 1 */
"ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/
"ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/
/* r1, r3, mul w4, get out r0, r1 */
"fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3,
5, 7]*/
"fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3,
5, 7]*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"fmla v21.4s , %[w4].4s, v3.4s\n" /* outr0 = w4 * r1*/
"fmla v22.4s , %[w4].4s, v7.4s\n" /* outr1 = w4 * r3*/
"ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
/* r1, r3, mul w5, get out r0, r1 */
"fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/
"fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/
"fmla v21.4s , %[w5].4s, v15.4s\n" /* outr0 = w5 * r1[2]*/
"fmla v22.4s , %[w5].4s, v16.4s\n" /* outr1 = w5 * r1[4]*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/
"ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"str q21, [%[ptr_out0]], #16\n" /* save outr00, outr01*/
/* r4, mul w6, get out r1 */
"fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4[0, 2,
4, 6]*/
"fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4*/
"ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/
"ldr q21, [%[ptr_out0]] \n" /* load outr0*/
/* r4, mul w7, get out r1 */
"fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3,
5, 7]*/
"fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4*/
/* r4, mul w8, get out r1 */
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4,
6, 8]*/
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"str q22, [%[ptr_out1]], #16 \n" /* save outr1*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v21",
"v22");
: [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
[w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5),
[w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8)
: "cc","memory","v0","v1","v2","v3",
"v4","v5","v6","v7","v8","v9","v10","v11",
"v12","v13","v14","v15","v16","v21","v22");
// clang-format on
wc0 += 36;
inr0 += win_round;
inr1 += win_round;
......@@ -944,184 +681,92 @@ void conv_3x3s2_direct_fp32(const float* i_data,
int cnt = w_loop / 2;
if (cnt > 0) {
// clang-format off
asm volatile(
/* main loop */
"0: @ "
"main loop\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
"or01\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, "
"or11\n"
"vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 "
"float, interleave\n"
"vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 "
"float, interleave\n"
"vld1.32 {d22}, [%[r2]] @ load 16th "
"float\n"
"0: @ main loop\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, or11\n"
"vld2.32 {d6-d9}, [%[r2]]! @ load r2\n"
"vld2.32 {d10-d13}, [%[r2]]! @ load r2\n"
"vld1.32 {d22}, [%[r2]] @ load 16th float\n"
/* r2 * w2, r2 * w0, get or0, or1 */
"vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, "
"9, 11, 13, 15\n"
"vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 "
"float, interleave\n"
"vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, "
"9, 11, 13, 15\n"
"vext.32 q4, q3, q5, #1 @ r2, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r2, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, "
"8, 10, 12, 14\n"
"vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 "
"float, interleave\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, "
"8, 10, 12, 14\n"
"vld1.32 {d22}, [%[r0]] @ load 16th "
"float\n"
"vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, "
"2, 4, 6, 8\n"
"vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, "
"2, 4, 6, 8\n"
"vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 "
"float, interleave\n"
"vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, "
"10, 12, 14, 16\n"
"vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, "
"10, 12, 14, 16\n"
"vld2.32 {d10-d13}, [%[r3]]! @ load r3, 8 "
"float, interleave\n"
"vmla.f32 q12, q4, %e[w2][1] @ w21 * r2\n"
"vmla.f32 q13, q6, %e[w2][1] @ w21 * r2\n "
"vld2.32 {d14-d17}, [%[r0]]! @ load r0\n"
"vmla.f32 q14, q4, %e[w0][1] @ w01 * r2\n"
"vmla.f32 q15, q6, %e[w0][1] @ w01 * r2\n"
"vext.32 q4, q3, q5, #1 @ r2, shift left 1\n"
"vext.32 q6, q5, q11, #1 @ r2, shift left 1\n"
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n"
"vmla.f32 q13, q5, %e[w2][0] @ w20 * r2\n"
"vld2.32 {d18-d21}, [%[r0]]! @ load r0\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r2\n"
"vmla.f32 q15, q5, %e[w0][0] @ w00 * r2\n"
"vld1.32 {d22}, [%[r0]] @ load 16th float\n"
"vmla.f32 q12, q4, %f[w2][0] @ w22 * r2\n"
"vmla.f32 q14, q4, %f[w0][0] @ w02 * r2\n"
"vld2.32 {d6-d9}, [%[r3]]! @ load r3\n"
"vmla.f32 q13, q6, %f[w2][0] @ w22 * r2\n"
"vmla.f32 q15, q6, %f[w0][0] @ w02 * r2\n"
"vld2.32 {d10-d13}, [%[r3]]! @ load r3\n"
/* r0 * w0, get or0, r3 * w1, get or1*/
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, "
"9, 11, 13, 15\n"
"vext.32 q8, q7, q9, #1 @ r0, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q10, q9, q11, #1 @ r0, shift "
"left 1, get 10, 12, 14, 16\n"
"vld1.32 {d22}, [%[r3]] @ load 16th "
"float\n"
"vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, "
"9, 11, 13, 15\n"
"vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, "
"8, 10, 12, 14\n"
"vext.32 q4, q3, q5, #1 @ r3, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r3, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, "
"8, 10, 12, 14\n"
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n"
"vmla.f32 q13, q10, %e[w0][1] @ w01 * r0\n"
"vext.32 q8, q7, q9, #1 @ r0, shift left 1\n"
"vext.32 q10, q9, q11, #1 @ r0, shift left 1\n"
"vld1.32 {d22}, [%[r3]] @ load 16th float\n"
"vmla.f32 q14, q4, %e[w1][1] @ w11 * r3\n"
"vmla.f32 q15, q6, %e[w1][1] @ w11 * r3\n"
"vmla.f32 q12, q7, %e[w0][0] @ w00 * r0\n"
"vmla.f32 q13, q9, %e[w0][0] @ w00 * r0\n"
"vext.32 q4, q3, q5, #1 @ r3, shift left 1\n"
"vext.32 q6, q5, q11, #1 @ r3, shift left 1\n"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r3\n"
"vmla.f32 q15, q5, %e[w1][0] @ w10 * r3\n"
"vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, "
"2, 4, 6, 8\n"
"vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 "
"float, interleave\n"
"vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, "
"10, 12, 14, 16\n"
"vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 "
"float, interleave\n"
"vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, "
"2, 4, 6, 8\n"
"vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 "
"float, interleave\n"
"vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, "
"10, 12, 14, 16\n"
"vld2.32 {d10-d13}, [%[r4]]! @ load r4, 8 "
"float, interleave\n"
"vld1.32 {d22}, [%[r1]] @ load 16th "
"float\n"
"vld2.32 {d14-d17}, [%[r1]]! @ load r1\n"
"vmla.f32 q13, q10,%f[w0][0] @ w02 * r0\n"
"vld2.32 {d18-d21}, [%[r1]]! @ load r1\n"
"vmla.f32 q14, q4, %f[w1][0] @ w12 * r3\n"
"vld2.32 {d6-d9}, [%[r4]]! @ load r4\n"
"vmla.f32 q15, q6, %f[w1][0] @ w12 * r3\n"
"vld2.32 {d10-d13}, [%[r4]]! @ load r4\n"
"vld1.32 {d22}, [%[r1]] @ load 16th float\n"
/* r1 * w1, get or0, r4 * w2, get or1 */
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, "
"9, 11, 13, 15\n"
"vext.32 q8, q7, q9, #1 @ r1, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q10, q9, q11, #1 @ r1, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, "
"9, 11, 13, 15\n"
"vld1.32 {d22}, [%[r4]] @ load 16th "
"float\n"
"vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, "
"8, 10, 12, 14\n"
"vext.32 q4, q3, q5, #1 @ r1, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r1, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, "
"8, 10, 12, 14\n"
"vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, "
"2, 4, 6, 8\n"
"vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, "
"10, 12, 14, 16\n"
"vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, "
"2, 4, 6, 8\n"
"vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, "
"10, 12, 14, 16\n"
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n"
"vmla.f32 q13, q10, %e[w1][1] @ w11 * r1\n"
"vext.32 q8, q7, q9, #1 @ r1, shift left 1\n"
"vext.32 q10, q9, q11, #1 @ r1, shift left 1\n"
"vmla.f32 q14, q4, %e[w2][1] @ w21 * r4\n"
"vmla.f32 q15, q6, %e[w2][1] @ w21 * r4\n"
"vld1.32 {d22}, [%[r4]] @ load 16th float\n"
"vmla.f32 q12, q7, %e[w1][0] @ w10 * r1\n"
"vmla.f32 q13, q9, %e[w1][0] @ w10 * r1\n"
"vext.32 q4, q3, q5, #1 @ r1, shift left 1\n"
"vext.32 q6, q5, q11, #1 @ r1, shift left 1\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r4\n"
"vmla.f32 q15, q5, %e[w2][0] @ w20 * r4\n"
"vmla.f32 q12, q8, %f[w1][0] @ w12 * r1\n"
"vmla.f32 q13, q10, %f[w1][0] @ w12 * r1\n"
"vmla.f32 q14, q4, %f[w2][0] @ w22 * r4\n"
"vmla.f32 q15, q6, %f[w2][0] @ w22 * r4\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n"
"subs %[cnt], #1 @loop count "
"-1\n"
"bne 0b @ jump to "
"main loop\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ jump to main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4),
[r0] "+r"(r0),[r1] "+r"(r1),[r2] "+r"(r2),
[r3] "+r"(r3),[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
: "cc","memory","q3","q4",
"q5","q6","q7","q8","q9","q10",
"q11","q12","q13","q14","q15"
);
// clang-format on
}
//! deal with remain ow
if (w_loop & 1) {
......
......@@ -23,6 +23,9 @@ namespace lite {
namespace arm {
namespace math {
/// conv 3x3s1
/// Returns the scratch-buffer size, in bytes, required by the 3x3 stride-1
/// direct convolution. The value is derived from the input/output dims in
/// `param` and from the thread count and last-level-cache size reported by
/// `ctx`; it covers one pre-packed input block plus one output block per
/// thread.
size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_direct_fp32(const float* din,
float* dout,
int num,
......@@ -53,6 +56,9 @@ void conv_3x3s1_direct_int8(const int8_t* din,
ARMContext* ctx,
const float* scale);
/// conv3x3s2
/// Workspace-size query for the 3x3 stride-2 direct convolution; the fp32
/// DirectConv kernel passes the result to ExtendWorkspace when
/// strides[0] == 2.
/// NOTE(review): presumably returns a byte count like the stride-1 variant;
/// the definition is not visible here, so confirm.
size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s2_direct_fp32(const float* din,
float* dout,
int num,
......
......@@ -1104,13 +1104,13 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) {
SetCacheInfo(0, 1, l1size);
SetCacheInfo(1, 1, l2size);
SetCacheInfo(2, 1, l3size);
workspace_.Resize({2 * (l1size + l2size)});
workspace_.Resize({llc_size()});
workspace_.mutable_data<int8_t>();
}
// Resizes workspace_ to `size + llc_size()` and then requests its int8_t
// buffer through mutable_data.
// NOTE(review): this span interleaves the pre- and post-change lines of the
// commit diff (parameter `int` -> `size_t`; `return true` -> a null check on
// the buffer returned by mutable_data). Confirm only the post-change variant
// is kept in the real file.
bool DeviceInfo::ExtendWorkspace(int size) {
bool DeviceInfo::ExtendWorkspace(size_t size) {
workspace_.Resize({size + llc_size()});
workspace_.mutable_data<int8_t>();
return true;
return workspace_.mutable_data<int8_t>() != nullptr;
}
#endif // LITE_WITH_ARM
......
......@@ -73,7 +73,7 @@ class DeviceInfo {
T* workspace_data() {
return reinterpret_cast<T*>(workspace_.mutable_data<int8_t>());
}
bool ExtendWorkspace(int size);
bool ExtendWorkspace(size_t size);
private:
int core_num_;
......
......@@ -19,6 +19,25 @@ namespace lite {
namespace kernels {
namespace arm {
/// Re-sizes the ARM workspace for the fp32 direct 3x3 convolution whenever
/// the input shape differs from the one recorded on the previous call.
///
/// The required size comes from the stride-specific helper
/// (conv3x3s2_direct_workspace_size when strides[0] == 2, otherwise
/// conv3x3s1_direct_workspace_size) and is handed to ExtendWorkspace.
/// The shape is then stored in last_shape_ so that later calls with an
/// unchanged input return early instead of extending the workspace again.
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
  auto& param = this->template Param<param_t>();
  auto x_dims = param.x->dims();
  if (last_shape_ == x_dims) {
    // Workspace was already extended for this input shape.
    return;
  }
  auto& ctx = this->ctx_->template As<ARMContext>();
  if (param.strides[0] == 2) {
    ctx.ExtendWorkspace(
        lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx));
  } else {
    // Any stride other than 2 is treated as the stride-1 kernel here.
    ctx.ExtendWorkspace(
        lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx));
  }
  // Remember the shape; without this the early return above is never armed
  // by this function and the workspace is re-extended on every call.
  last_shape_ = x_dims;
}
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -70,6 +89,9 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
}
/// No-op: the int8-input / float-output specialization performs no
/// re-initialization work.
/// NOTE(review): presumably the int8 direct kernels do not rely on the
/// extended workspace; confirm.
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -126,6 +148,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
}
}
/// No-op: the int8-input / int8-output specialization performs no
/// re-initialization work.
/// NOTE(review): presumably the int8 direct kernels do not rely on the
/// extended workspace; confirm.
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::ReInitWhenNeeded() {}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
......
......@@ -178,10 +178,12 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
w_scale_);
}
virtual void ReInitWhenNeeded();
virtual void Run();
/// todo, support inplace weights transform
protected:
DDim last_shape_;
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册