提交 f03217b4 编写于 作者: X Xiaoyang LI 提交者: Yan Chunwei

add workspace compute funcs for direct conv, test=develop (#2132)

上级 4de0f778
...@@ -26,6 +26,39 @@ namespace lite { ...@@ -26,6 +26,39 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// Tiling factors shared by the 3x3 stride-1 direct-convolution kernels and
// their workspace-size calculation (see conv3x3s1_direct_workspace_size).
const int OUT_C_BLOCK = 4;  // output channels packed per inner block
const int OUT_H_BLOCK = 2;  // output rows produced per kernel iteration
const int OUT_W_BLOCK = 4;  // output columns produced per kernel iteration
// Returns the scratch-workspace size in bytes required by the 3x3 stride-1
// direct convolution for the given parameters.
//
// The workspace holds one shared pre-packed input tile (pre_in_size floats)
// plus one packed output tile (pre_out_size floats) per thread.  The number
// of output rows per tile (hout_r_block) is chosen so that the input tile
// and all per-thread output tiles fit in the context's last-level cache,
// then clamped to [OUT_H_BLOCK, oh] and rounded down to a multiple of
// OUT_H_BLOCK.  This sizing must stay in sync with the tiling logic inside
// conv_3x3s1_direct_fp32.
size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
                                       ARMContext* ctx) {
  auto dim_in = param.x->dims();
  auto dim_out = param.output->dims();
  const int threads = ctx->threads();
  // llc_size() is in bytes; convert to a budget of float elements.
  int llc_size = ctx->llc_size() / sizeof(float);
  int ow = dim_out[3];
  int oh = dim_out[2];
  int ic = dim_in[1];
  // Round output width up to the width tile; a 3x3 kernel needs 2 extra
  // input columns per output row.
  const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
  const int win_round = wout_round + 2;
  int hout_r_block = (llc_size - 2 * win_round * ic) /
                     (win_round * ic + OUT_C_BLOCK * wout_round * threads);
  hout_r_block = hout_r_block > oh ? oh : hout_r_block;
  hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
  hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
  // Input tile needs 2 extra rows (3x3 kernel, stride 1).
  const int hin_r_block = hout_r_block + 2;
  int in_len = win_round * ic;
  int pre_in_size = hin_r_block * in_len;
  int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
  // One shared packed-input buffer plus one packed-output buffer per thread.
  return sizeof(float) * (pre_in_size + threads * pre_out_size);
}
void conv_3x3s1_direct_fp32(const float* i_data, void conv_3x3s1_direct_fp32(const float* i_data,
float* o_data, float* o_data,
int bs, int bs,
...@@ -44,19 +77,16 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -44,19 +77,16 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const int pad_h = param.paddings[0]; const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1]; const int pad_w = param.paddings[1];
const int hout_c_block = 4; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round + 2; const int win_round = wout_round + 2;
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
int hout_r_block = (l2_size - 2 * win_round * ic) / int hout_r_block = (l2_size - 2 * win_round * ic) /
(win_round * ic + hout_c_block * wout_round * threads); (win_round * ic + OUT_C_BLOCK * wout_round * threads);
hout_r_block = hout_r_block > oh ? oh : hout_r_block; hout_r_block = hout_r_block > oh ? oh : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
const int hin_r_block = hout_r_block + 2; const int hin_r_block = hout_r_block + 2;
...@@ -67,23 +97,23 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -67,23 +97,23 @@ void conv_3x3s1_direct_fp32(const float* i_data,
int in_len = win_round * ic; int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len; int pre_in_size = hin_r_block * in_len;
int pre_out_size = hout_c_block * hout_r_block * wout_round; int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
float* pre_din = tmp_work_space; float* pre_din = tmp_work_space;
int size_in_channel = win * ih; int size_in_channel = win * ih;
int size_out_channel = ow * oh; int size_out_channel = ow * oh;
int w_stride = ic * 9; // kernel_w * kernel_h; int w_stride = ic * 9; // kernel_w * kernel_h;
int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h *
int ws = -pad_w; int ws = -pad_w;
int we = ws + win_round; int we = ws + win_round;
int w_loop = wout_round / 4; int w_loop = wout_round / 4;
int c_remain = oc - (oc / hout_c_block) * hout_c_block; int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int c_round_down = (oc / hout_c_block) * hout_c_block; int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int out_row_stride = hout_c_block * wout_round; int out_row_stride = OUT_C_BLOCK * wout_round;
for (int n = 0; n < bs; ++n) { for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel; const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel; float* dout_batch = o_data + n * oc * size_out_channel;
...@@ -97,7 +127,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -97,7 +127,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
prepack_input_nxw( prepack_input_nxw(
din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero); din_batch, pre_din, 0, ic, hs, he, ws, we, ic, win, ih, ptr_zero);
#pragma omp parallel for num_threads(threads) #pragma omp parallel for num_threads(threads)
for (int c = 0; c < oc - (hout_c_block - 1); c += hout_c_block) { for (int c = 0; c < oc - (OUT_C_BLOCK - 1); c += OUT_C_BLOCK) {
#ifdef ARM_WITH_OMP #ifdef ARM_WITH_OMP
float* pre_out = float* pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
...@@ -115,9 +145,9 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -115,9 +145,9 @@ void conv_3x3s1_direct_fp32(const float* i_data,
bias_ptr = bias + c; bias_ptr = bias + c;
} }
fill_packed_biasc4( fill_packed_biasc4(
pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_c; const float* wc0 = weight_c;
const float* inr0 = block_inr0; const float* inr0 = block_inr0;
...@@ -148,161 +178,125 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -148,161 +178,125 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3; const float* r3 = inr3;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, "ldp q15, q16, [%[ptr_out0]]\n" /* load outr00,outr01*/
outr01*/ "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr10, outr11*/ "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ "2: \n" /* main loop*/
"2: \n" /* main loop*/ /* r0, r1, mul w0, get out r0, r1 */
/* r0, r1, mul w0, get out r0, r1 */ "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/
"fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ "fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/
"fmla v16.4s , %[w0].4s, v0.s[1]\n" /* outr01 = w0 * r0[1]*/ "fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/
"fmla v17.4s , %[w0].4s, v0.s[2]\n" /* outr02 = w0 * r0[2]*/ "fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/
"fmla v18.4s , %[w0].4s, v0.s[3]\n" /* outr03 = w0 * r0[3]*/ "fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/
"fmla v19.4s , %[w0].4s, v2.s[0]\n" /* outr10 = w0 * r1[0]*/ "fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/
"fmla v20.4s , %[w0].4s, v2.s[1]\n" /* outr11 = w0 * r1[1]*/ "fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/
"fmla v21.4s , %[w0].4s, v2.s[2]\n" /* outr12 = w0 * r1[2]*/ "fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/
"fmla v22.4s , %[w0].4s, v2.s[3]\n" /* outr13 = w0 * r1[3]*/ /* r0, r1, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/
/* r0, r1, mul w1, get out r0, r1 */ "fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ "fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/
"fmla v16.4s , %[w1].4s, v0.s[2]\n" /* outr01 = w1 * r0[2]*/ "fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/
"fmla v17.4s , %[w1].4s, v0.s[3]\n" /* outr02 = w1 * r0[3]*/ "fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/
"fmla v18.4s , %[w1].4s, v1.s[0]\n" /* outr03 = w1 * r0[4]*/ "fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/
"fmla v19.4s , %[w1].4s, v2.s[1]\n" /* outr10 = w1 * r1[1]*/ "fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/
"fmla v20.4s , %[w1].4s, v2.s[2]\n" /* outr11 = w1 * r1[2]*/ "fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/
"fmla v21.4s , %[w1].4s, v2.s[3]\n" /* outr12 = w1 * r1[3]*/ "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"fmla v22.4s , %[w1].4s, v3.s[0]\n" /* outr13 = w1 * r1[4]*/ /* r0, r1, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ "fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/
"fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/
/* r0, r1, mul w2, get out r0, r1 */ "fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ "fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/
"fmla v16.4s , %[w2].4s, v0.s[3]\n" /* outr01 = w2 * r0[3]*/ "fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/
"fmla v17.4s , %[w2].4s, v1.s[0]\n" /* outr02 = w2 * r0[0]*/ "fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/
"fmla v18.4s , %[w2].4s, v1.s[1]\n" /* outr03 = w2 * r0[1]*/ "fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/
"fmla v19.4s , %[w2].4s, v2.s[2]\n" /* outr10 = w2 * r1[2]*/ /* r1, r2, mul w3, get out r0, r1 */
"fmla v20.4s , %[w2].4s, v2.s[3]\n" /* outr11 = w2 * r1[3]*/ "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/
"fmla v21.4s , %[w2].4s, v3.s[0]\n" /* outr12 = w2 * r1[0]*/ "fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/
"fmla v22.4s , %[w2].4s, v3.s[1]\n" /* outr13 = w2 * r1[1]*/ "fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/
"fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/
/* r1, r2, mul w3, get out r0, r1 */ "fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/
"fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ "fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/
"fmla v16.4s , %[w3].4s, v2.s[1]\n" /* outr01 = w3 * r1[1]*/ "fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/
"fmla v17.4s , %[w3].4s, v2.s[2]\n" /* outr02 = w3 * r1[2]*/ "fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/
"fmla v18.4s , %[w3].4s, v2.s[3]\n" /* outr03 = w3 * r1[3]*/ "ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/
"fmla v19.4s , %[w3].4s, v4.s[0]\n" /* outr10 = w3 * r2[0]*/ /* r1, r2, mul w4, get out r0, r1 */
"fmla v20.4s , %[w3].4s, v4.s[1]\n" /* outr11 = w3 * r2[1]*/ "fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/
"fmla v21.4s , %[w3].4s, v4.s[2]\n" /* outr12 = w3 * r2[2]*/ "fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/
"fmla v22.4s , %[w3].4s, v4.s[3]\n" /* outr13 = w3 * r2[3]*/ "fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/
"fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/
"ldp q0, q1, [%[r0]], #16 \n" /* load next input r0*/ "fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/
"fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/
/* r1, r2, mul w4, get out r0, r1 */ "fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ "fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/
"fmla v16.4s , %[w4].4s, v2.s[2]\n" /* outr01 = w4 * r1[2]*/ "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"fmla v17.4s , %[w4].4s, v2.s[3]\n" /* outr02 = w4 * r1[3]*/ /* r1, r2, mul w5, get out r0, r1 */
"fmla v18.4s , %[w4].4s, v3.s[0]\n" /* outr03 = w4 * r1[4]*/ "fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/
"fmla v19.4s , %[w4].4s, v4.s[1]\n" /* outr10 = w4 * r2[1]*/ "fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/
"fmla v20.4s , %[w4].4s, v4.s[2]\n" /* outr11 = w4 * r2[2]*/ "fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/
"fmla v21.4s , %[w4].4s, v4.s[3]\n" /* outr12 = w4 * r2[3]*/ "fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/
"fmla v22.4s , %[w4].4s, v5.s[0]\n" /* outr13 = w4 * r2[4]*/ "fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/
"fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ "fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/
"fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/
/* r1, r2, mul w5, get out r0, r1 */ /* r2, r3, mul w6, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w5].4s, v2.s[3]\n" /* outr01 = w5 * r1[3]*/ "fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/
"fmla v17.4s , %[w5].4s, v3.s[0]\n" /* outr02 = w5 * r1[0]*/ "fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/
"fmla v18.4s , %[w5].4s, v3.s[1]\n" /* outr03 = w5 * r1[1]*/ "fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/
"fmla v19.4s , %[w5].4s, v4.s[2]\n" /* outr10 = w5 * r2[2]*/ "fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/
"fmla v20.4s , %[w5].4s, v4.s[3]\n" /* outr11 = w5 * r2[3]*/ "fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/
"fmla v21.4s , %[w5].4s, v5.s[0]\n" /* outr12 = w5 * r2[0]*/ "fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/
"fmla v22.4s , %[w5].4s, v5.s[1]\n" /* outr13 = w5 * r2[1]*/ "fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/
"ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/
/* r2, r3, mul w6, get out r0, r1 */ /* r2, r3, mul w7, get out r0, r1 */
"fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w6].4s, v4.s[1]\n" /* outr01 = w6 * r2[1]*/ "fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/
"fmla v17.4s , %[w6].4s, v4.s[2]\n" /* outr02 = w6 * r2[2]*/ "fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/
"fmla v18.4s , %[w6].4s, v4.s[3]\n" /* outr03 = w6 * r2[3]*/ "fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/
"fmla v19.4s , %[w6].4s, v6.s[0]\n" /* outr10 = w6 * r3[0]*/ "fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/
"fmla v20.4s , %[w6].4s, v6.s[1]\n" /* outr11 = w6 * r3[1]*/ "fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/
"fmla v21.4s , %[w6].4s, v6.s[2]\n" /* outr12 = w6 * r3[2]*/ "fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/
"fmla v22.4s , %[w6].4s, v6.s[3]\n" /* outr13 = w6 * r3[3]*/ "fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"ldp q2, q3, [%[r1]], #16 \n" /* load next input r1*/ /* r2, r3, mul w8, get out r0, r1 */
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/
/* r2, r3, mul w7, get out r0, r1 */ "fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/
"fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ "fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/
"fmla v16.4s , %[w7].4s, v4.s[2]\n" /* outr01 = w7 * r2[2]*/ "fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/
"fmla v17.4s , %[w7].4s, v4.s[3]\n" /* outr02 = w7 * r2[3]*/ "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/
"fmla v18.4s , %[w7].4s, v5.s[0]\n" /* outr03 = w7 * r2[4]*/ "fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/
"fmla v19.4s , %[w7].4s, v6.s[1]\n" /* outr10 = w7 * r3[1]*/ "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/
"fmla v20.4s , %[w7].4s, v6.s[2]\n" /* outr11 = w7 * r3[2]*/ "fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/
"fmla v21.4s , %[w7].4s, v6.s[3]\n" /* outr12 = w7 * r3[3]*/ "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
"fmla v22.4s , %[w7].4s, v7.s[0]\n" /* outr13 = w7 * r3[4]*/ "fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ "fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/
"stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/
/* r2, r3, mul w8, get out r0, r1 */ "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
"fmla v16.4s , %[w8].4s, v4.s[3]\n" /* outr01 = w8 * r2[3]*/ "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
"fmla v17.4s , %[w8].4s, v5.s[0]\n" /* outr02 = w8 * r2[0]*/ "bne 2b \n" /* jump to main loop*/
"fmla v18.4s , %[w8].4s, v5.s[1]\n" /* outr03 = w8 * r2[1]*/ : [cnt] "+r"(cnt),
[r0] "+r"(r0),[r1] "+r"(r1),
"stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ [r2] "+r"(r2),[r3] "+r"(r3),
"fmla v19.4s , %[w8].4s, v6.s[2]\n" /* outr10 = w8 * r3[2]*/ [ptr_out0] "+r"(ptr_out0),
"stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ [ptr_out1] "+r"(ptr_out1)
"fmla v20.4s , %[w8].4s, v6.s[3]\n" /* outr11 = w8 * r3[3]*/ : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/ [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5),
"fmla v21.4s , %[w8].4s, v7.s[0]\n" /* outr12 = w8 * r3[0]*/ [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8)
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ : "cc","memory","v0","v1","v2","v3",
"fmla v22.4s , %[w8].4s, v7.s[1]\n" /* outr13 = w8 * r3[1]*/ "v4","v5","v6","v7","v15","v16",
"stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/ "v17","v18","v19","v20","v21","v22"
"stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/ );
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ // clang-format on
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
"bne 2b \n" /* jump to main loop*/ wc0 += 9 * OUT_C_BLOCK;
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
wc0 += 9 * hout_c_block;
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
inr2 += win_round; inr2 += win_round;
...@@ -321,273 +315,135 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -321,273 +315,135 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3; const float* r3 = inr3;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ " "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"load outr0, w0, w1, c0~c3\n" "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load " /* load weights */
"outr0, w2, w3, c0~c3\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
/* load weights */ /* load r0, r1 */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n"
"w1, to q5, q6\n" "vld1.32 {d2}, [%[r0]] @ load r0\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
"to q7\n" /* main loop */
"0: @ main loop\n"
/* load r0, r1 */ /* mul r0 with w0, w1, w2, get out r0 */
"vld1.32 {d0-d1}, [%[r0]]! @ load r0, " "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n"
"4 float\n" "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n"
"vld1.32 {d2}, [%[r0]] @ load r0, " "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n"
"2 float\n" "vmla.f32 q9, q5, d0[1] @ w0 * inr01\n"
"vmla.f32 q10, q5, d1[0] @ w0 * inr02\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " "vmla.f32 q11, q5, d1[1] @ w0 * inr03\n"
"- 32, to start address\n" "vld1.32 {d3-d4}, [%[r1]]! @ load r1\n"
"vmla.f32 q8, q6, d0[1] @ w1 * inr01\n"
/* main loop */ "vmla.f32 q9, q6, d1[0] @ w1 * inr02\n"
"0: @ main " "vmla.f32 q10, q6, d1[1] @ w1 * inr03\n"
"loop\n" "vmla.f32 q11, q6, d2[0] @ w1 * inr04\n"
/* mul r0 with w0, w1, w2, get out r0 */ "vld1.32 {d5}, [%[r1]] @ load r0\n"
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n"
"outr1, w0, w1, c0~c3\n" "vmla.f32 q9, q7, d1[1] @ w2 * inr03\n"
"vmla.f32 q8, q5, d0[0] @ w0 * " "vmla.f32 q10, q7, d2[0] @ w2 * inr04\n"
"inr00\n" "vmla.f32 q11, q7, d2[1] @ w2 * inr05\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load " "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n"
"outr1, w2, w3, c0~c3\n" /* mul r1 with w0, w1, w2, get out r1 */
"vmla.f32 q9, q5, d0[1] @ w0 * " "vmla.f32 q12, q5, d3[0] @ w0 * inr10\n"
"inr01\n" "vmla.f32 q13, q5, d3[1] @ w0 * inr11\n"
"vmla.f32 q10, q5, d1[0] @ w0 * " "vmla.f32 q14, q5, d4[0] @ w0 * inr12\n"
"inr02\n" "vmla.f32 q15, q5, d4[1] @ w0 * inr13\n"
"vmla.f32 q11, q5, d1[1] @ w0 * " "vmla.f32 q12, q6, d3[1] @ w1 * inr11\n"
"inr03\n" "vmla.f32 q13, q6, d4[0] @ w1 * inr12\n"
"vld1.32 {d3-d4}, [%[r1]]! @ load r1, " "vmla.f32 q14, q6, d4[1] @ w1 * inr13\n"
"4 float\n" "vmla.f32 q15, q6, d5[0] @ w1 * inr14\n"
"vmla.f32 q8, q6, d0[1] @ w1 * " "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4\n"
"inr01\n" "vmla.f32 q12, q7, d4[0] @ w2 * inr12\n"
"vmla.f32 q9, q6, d1[0] @ w1 * " "vmla.f32 q13, q7, d4[1] @ w2 * inr13\n"
"inr02\n" "vmla.f32 q14, q7, d5[0] @ w2 * inr14\n"
"vmla.f32 q10, q6, d1[1] @ w1 * " "vmla.f32 q15, q7, d5[1] @ w2 * inr15\n"
"inr03\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w5\n"
"vmla.f32 q11, q6, d2[0] @ w1 * " /* mul r1 with w3, w4, w5, get out r0 */
"inr04\n" "vmla.f32 q8, q5, d3[0] @ w3 * inr10\n"
"vld1.32 {d5}, [%[r1]] @ load r0, " "vmla.f32 q9, q5, d3[1] @ w3 * inr11\n"
"2 float\n" "vmla.f32 q10, q5, d4[0] @ w3 * inr12\n"
"vmla.f32 q8, q7, d1[0] @ w2 * " "vmla.f32 q11, q5, d4[1] @ w3 * inr13\n"
"inr02\n" "vld1.32 {d0-d1}, [%[r2]]! @ load r2\n"
"vmla.f32 q9, q7, d1[1] @ w2 * " "vmla.f32 q8, q6, d3[1] @ w4 * inr11\n"
"inr03\n" "vmla.f32 q9, q6, d4[0] @ w4 * inr12\n"
"vmla.f32 q10, q7, d2[0] @ w2 * " "vmla.f32 q10, q6, d4[1] @ w4 * inr13\n"
"inr04\n" "vmla.f32 q11, q6, d5[0] @ w4 * inr14\n"
"vmla.f32 q11, q7, d2[1] @ w2 * " "vld1.32 {d2}, [%[r2]] @ load r2\n"
"inr05\n" "vmla.f32 q8, q7, d4[0] @ w5 * inr12\n"
"vmla.f32 q9, q7, d4[1] @ w5 * inr13\n"
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " "vmla.f32 q10, q7, d5[0] @ w5 * inr14\n"
"- 32, to start address\n" "vmla.f32 q11, q7, d5[1] @ w5 * inr15\n"
/* mul r2 with w3, w4, w5, get out r1 */
/* mul r1 with w0, w1, w2, get out r1 */ "vmla.f32 q12, q5, d0[0] @ w3 * inr20\n"
"vmla.f32 q12, q5, d3[0] @ w0 * " "vmla.f32 q13, q5, d0[1] @ w3 * inr21\n"
"inr10\n" "vmla.f32 q14, q5, d1[0] @ w3 * inr22\n"
"vmla.f32 q13, q5, d3[1] @ w0 * " "vmla.f32 q15, q5, d1[1] @ w3 * inr23\n"
"inr11\n" "vmla.f32 q12, q6, d0[1] @ w4 * inr21\n"
"vmla.f32 q14, q5, d4[0] @ w0 * " "vmla.f32 q13, q6, d1[0] @ w4 * inr22\n"
"inr12\n" "vmla.f32 q14, q6, d1[1] @ w4 * inr23\n"
"vmla.f32 q15, q5, d4[1] @ w0 * " "vmla.f32 q15, q6, d2[0] @ w4 * inr24\n"
"inr13\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n"
"vmla.f32 q12, q6, d3[1] @ w1 * " "vmla.f32 q12, q7, d1[0] @ w5 * inr22\n"
"inr11\n" "vmla.f32 q13, q7, d1[1] @ w5 * inr23\n"
"vmla.f32 q13, q6, d4[0] @ w1 * " "vmla.f32 q14, q7, d2[0] @ w5 * inr24\n"
"inr12\n" "vmla.f32 q15, q7, d2[1] @ w5 * inr25\n"
"vmla.f32 q14, q6, d4[1] @ w1 * " "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n"
"inr13\n" "sub %[wc0], %[wc0], #144 @ wc0 - 144\n"
"vmla.f32 q15, q6, d5[0] @ w1 * " /* mul r2 with w6, w7, w8, get out r0 */
"inr14\n" "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " "vmla.f32 q9, q5, d0[1] @ w6 * inr21\n"
"w4, to q5, q6\n" "vld1.32 {d3-d4}, [%[r3]]! @ load r3\n"
"vmla.f32 q12, q7, d4[0] @ w2 * " "vmla.f32 q10, q5, d1[0] @ w6 * inr22\n"
"inr12\n" "vmla.f32 q11, q5, d1[1] @ w6 * inr23\n"
"vmla.f32 q13, q7, d4[1] @ w2 * " "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n"
"inr13\n" "vmla.f32 q9, q6, d1[0] @ w7 * inr22\n"
"vmla.f32 q14, q7, d5[0] @ w2 * " "vld1.32 {d5}, [%[r3]] @ load r3\n"
"inr14\n" "vmla.f32 q10, q6, d1[1] @ w7 * inr23\n"
"vmla.f32 q15, q7, d5[1] @ w2 * " "vmla.f32 q11, q6, d2[0] @ w7 * inr24\n"
"inr15\n" "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " "vmla.f32 q9, q7, d1[1] @ w8 * inr23\n"
"to q7\n" "vld1.32 {d0-d1}, [%[r0]]! @ load r0\n"
"vmla.f32 q10, q7, d2[0] @ w8 * inr24\n"
/* mul r1 with w3, w4, w5, get out r0 */ "vmla.f32 q11, q7, d2[1] @ w8 * inr25\n"
"vmla.f32 q8, q5, d3[0] @ w3 * " "vld1.32 {d2}, [%[r0]] @ load r0\n"
"inr10\n" /* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q9, q5, d3[1] @ w3 * " "vmla.f32 q12, q5, d3[0] @ w6 * inr20\n"
"inr11\n" "vmla.f32 q13, q5, d3[1] @ w6 * inr21\n"
"vmla.f32 q10, q5, d4[0] @ w3 * " "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n"
"inr12\n" "vmla.f32 q14, q5, d4[0] @ w6 * inr22\n"
"vmla.f32 q11, q5, d4[1] @ w3 * " "vmla.f32 q15, q5, d4[1] @ w6 * inr23\n"
"inr13\n" "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n"
"vld1.32 {d0-d1}, [%[r2]]! @ load r2, " "vmla.f32 q12, q6, d3[1] @ w7 * inr21\n"
"4 float\n" "vmla.f32 q13, q6, d4[0] @ w7 * inr22\n"
"vmla.f32 q8, q6, d3[1] @ w4 * " "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"inr11\n" "vmla.f32 q14, q6, d4[1] @ w7 * inr23\n"
"vmla.f32 q9, q6, d4[0] @ w4 * " "vmla.f32 q15, q6, d5[0] @ w7 * inr24\n"
"inr12\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vmla.f32 q10, q6, d4[1] @ w4 * " "vmla.f32 q12, q7, d4[0] @ w8 * inr22\n"
"inr13\n" "vmla.f32 q13, q7, d4[1] @ w8 * inr23\n"
"vmla.f32 q11, q6, d5[0] @ w4 * " "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
"inr14\n" "vmla.f32 q14, q7, d5[0] @ w8 * inr24\n"
"vld1.32 {d2}, [%[r2]] @ load r2, " "vmla.f32 q15, q7, d5[1] @ w8 * inr25\n"
"2 float\n" "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n"
"vmla.f32 q8, q7, d4[0] @ w5 * " "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n"
"inr12\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
"vmla.f32 q9, q7, d4[1] @ w5 * " "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
"inr13\n" "subs %[cnt], #1 @ loop count--\n"
"vmla.f32 q10, q7, d5[0] @ w5 * " "bne 0b @ jump to main loop\n"
"inr14\n" : [cnt] "+r"(cnt),
"vmla.f32 q11, q7, d5[1] @ w5 * " [r0] "+r"(r0),[r1] "+r"(r1),
"inr15\n" [r2] "+r"(r2),[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
/* mul r2 with w3, w4, w5, get out r1 */ [ptr_out1] "+r"(ptr_out1),
"vmla.f32 q12, q5, d0[0] @ w3 * " [wc0] "+r"(wc0)
"inr20\n" :
"vmla.f32 q13, q5, d0[1] @ w3 * " : "cc","memory","q0","q1","q2","q3",
"inr21\n" "q4","q5","q6","q7","q8","q9",
"vmla.f32 q14, q5, d1[0] @ w3 * " "q10","q11","q12","q13","q14","q15");
"inr22\n" // clang-format on
"vmla.f32 q15, q5, d1[1] @ w3 * "
"inr23\n"
"vmla.f32 q12, q6, d0[1] @ w4 * "
"inr21\n"
"vmla.f32 q13, q6, d1[0] @ w4 * "
"inr22\n"
"vmla.f32 q14, q6, d1[1] @ w4 * "
"inr23\n"
"vmla.f32 q15, q6, d2[0] @ w4 * "
"inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
"w7, to q5, q6\n"
"vmla.f32 q12, q7, d1[0] @ w5 * "
"inr22\n"
"vmla.f32 q13, q7, d1[1] @ w5 * "
"inr23\n"
"vmla.f32 q14, q7, d2[0] @ w5 * "
"inr24\n"
"vmla.f32 q15, q7, d2[1] @ w5 * "
"inr25\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
"to q7\n"
"sub %[wc0], %[wc0], #144 @ wc0 - "
"144 to start address\n"
/* mul r2 with w6, w7, w8, get out r0 */
"vmla.f32 q8, q5, d0[0] @ w6 * "
"inr20\n"
"vmla.f32 q9, q5, d0[1] @ w6 * "
"inr21\n"
"vld1.32 {d3-d4}, [%[r3]]! @ load r3, "
"4 float\n"
"vmla.f32 q10, q5, d1[0] @ w6 * "
"inr22\n"
"vmla.f32 q11, q5, d1[1] @ w6 * "
"inr23\n"
"vmla.f32 q8, q6, d0[1] @ w7 * "
"inr21\n"
"vmla.f32 q9, q6, d1[0] @ w7 * "
"inr22\n"
"vld1.32 {d5}, [%[r3]] @ load r3, "
"2 float\n"
"vmla.f32 q10, q6, d1[1] @ w7 * "
"inr23\n"
"vmla.f32 q11, q6, d2[0] @ w7 * "
"inr24\n"
"vmla.f32 q8, q7, d1[0] @ w8 * "
"inr22\n"
"vmla.f32 q9, q7, d1[1] @ w8 * "
"inr23\n"
"vld1.32 {d0-d1}, [%[r0]]! @ load r0, "
"4 float\n"
"vmla.f32 q10, q7, d2[0] @ w8 * "
"inr24\n"
"vmla.f32 q11, q7, d2[1] @ w8 * "
"inr25\n"
"vld1.32 {d2}, [%[r0]] @ load r0, "
"2 float\n"
/* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q12, q5, d3[0] @ w6 * "
"inr20\n"
"vmla.f32 q13, q5, d3[1] @ w6 * "
"inr21\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save "
"r00, r01, c0~c3\n"
"vmla.f32 q14, q5, d4[0] @ w6 * "
"inr22\n"
"vmla.f32 q15, q5, d4[1] @ w6 * "
"inr23\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save "
"r02, r03, c0~c3\n"
"vmla.f32 q12, q6, d3[1] @ w7 * "
"inr21\n"
"vmla.f32 q13, q6, d4[0] @ w7 * "
"inr22\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vmla.f32 q14, q6, d4[1] @ w7 * "
"inr23\n"
"vmla.f32 q15, q6, d5[0] @ w7 * "
"inr24\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vmla.f32 q12, q7, d4[0] @ w8 * "
"inr22\n"
"vmla.f32 q13, q7, d4[1] @ w8 * "
"inr23\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"vmla.f32 q14, q7, d5[0] @ w8 * "
"inr24\n"
"vmla.f32 q15, q7, d5[1] @ w8 * "
"inr25\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
"r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
"r12, r13, c0~c3\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"subs %[cnt], #1 @ loop "
"count--\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1),
[wc0] "+r"(wc0)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
inr2 += win_round; inr2 += win_round;
...@@ -602,7 +458,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -602,7 +458,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
write_to_output_c4_fp32(pre_out, write_to_output_c4_fp32(pre_out,
dout_batch, dout_batch,
c, c,
c + hout_c_block, c + OUT_C_BLOCK,
h, h,
h + h_kernel, h + h_kernel,
0, 0,
...@@ -641,7 +497,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -641,7 +497,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
} }
fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_remain_ptr; const float* wc0 = weight_remain_ptr;
const float* inr0 = block_inr0; const float* inr0 = block_inr0;
...@@ -672,109 +528,66 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -672,109 +528,66 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3; const float* r3 = inr3;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"ldr q21, [%[ptr_out0]] \n" /* load outr0, "ldr q21, [%[ptr_out0]]\n" /* load outr0, w0~w3*/
w0~w3*/ "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/ "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ "ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r1*/ "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ "2: \n" /* main loop*/
"2: \n" /* main loop*/ "fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/
"fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/
"fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0*/ "ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/
"fmla v22.4s , %[w0].4s, v2.4s \n" /* outr1 = w0 * r1*/ "ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/
"ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/
"ext v8.16b, v0.16b, v1.16b, #4 \n" /* shift r0 left 1*/ "ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/
"ext v10.16b, v2.16b, v3.16b, #4 \n" /* shift r1 left 1*/ "ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/
"ext v9.16b, v0.16b, v1.16b, #8 \n" /* shift r0 left 2*/ "fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/
"ext v11.16b, v2.16b, v3.16b, #8 \n" /* shift r1 left 2*/ "fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/
"fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/
"ldp q0, q1, [%[r0]], #16 \n" /* load input r0*/ "fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/
"fmla v21.4s , %[w1].4s, v8.4s \n" /* outr0 = w1 * r1*/ "fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/
"fmla v22.4s , %[w1].4s, v10.4s \n" /* outr1 = w1 * r2*/ "ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/
"ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/
"fmla v21.4s , %[w2].4s, v9.4s \n" /* outr0 = w2 * r1*/ "ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/
"fmla v22.4s , %[w2].4s, v11.4s \n" /* outr1 = w2 * r2*/ "ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/
"fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1*/ "fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/
"fmla v22.4s , %[w3].4s, v4.4s \n" /* outr1 = w3 * r2*/ "fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/
"fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/
"ext v12.16b, v4.16b, v5.16b, #4\n" /* shift r2 left 1*/ "ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/
"ext v14.16b, v6.16b, v7.16b, #4\n" /* shift r3 left 1*/ "fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/
"ext v13.16b, v4.16b, v5.16b, #8\n" /* shift r2 left 2*/ "fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/
"ext v15.16b, v6.16b, v7.16b, #8\n" /* shift r3 left 2*/ "ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/
"fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/
"fmla v21.4s , %[w4].4s, v10.4s \n" /* outr0 = w4 * r1*/ "fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/
"fmla v22.4s , %[w4].4s, v12.4s \n" /* outr1 = w4 * r2*/ "ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/
"fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/
"fmla v21.4s , %[w5].4s, v11.4s \n" /* outr0 = w5 * r1*/ "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/
"fmla v22.4s , %[w5].4s, v13.4s \n" /* outr1 = w5 * r2*/ "str q21, [%[ptr_out0]], #16 \n" /*write output r0*/
"str q22, [%[ptr_out1]], #16 \n" /*write output r1*/
"ldp q2, q3, [%[r1]], #16 \n" /* load input r0*/ "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/
"fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2*/ "ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"fmla v22.4s , %[w6].4s, v6.4s \n" /* outr1 = w6 * r3*/ "bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
"ldp q4, q5, [%[r2]], #16 \n" /* load input r2*/ [r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
"fmla v21.4s , %[w7].4s, v12.4s \n" /* outr0 = w7 * r1*/ [ptr_out0] "+r"(ptr_out0),
"fmla v22.4s , %[w7].4s, v14.4s \n" /* outr1 = w7 * r2*/ [ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
"ldp q6, q7, [%[r3]], #16 \n" /* load input r3*/ [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5),
[w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8)
"fmla v21.4s , %[w8].4s, v13.4s \n" /* outr0 = w8 * r1*/ : "cc","memory","v0",
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r2*/ "v1","v2","v3","v4","v5","v6",
"v7","v8","v9","v10","v11","v12",
"str q21, [%[ptr_out0]], #16 \n" /*write output r0*/ "v13","v14","v15","v21","v22"
"str q22, [%[ptr_out1]], #16 \n" /*write output r1*/ );
// clang-format on
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ wc0 += 9 * OUT_C_BLOCK;
"ldr q21, [%[ptr_out0]] \n" /* load outr0, w0~w3*/
"ldr q22, [%[ptr_out1]] \n" /* load outr1, w0~w3*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v21",
"v22");
wc0 += 9 * hout_c_block;
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
inr2 += win_round; inr2 += win_round;
...@@ -806,181 +619,96 @@ void conv_3x3s1_direct_fp32(const float* i_data, ...@@ -806,181 +619,96 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const float* r3 = inr3; const float* r3 = inr3;
int cnt = w_loop / 2; int cnt = w_loop / 2;
if (cnt > 0) { if (cnt > 0) {
// clang-format off
asm volatile( asm volatile(
"vld1.32 {d24-d27}, [%[ptr_out0]] @ " "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n"
"load or00, or01\n" "vld1.32 {d6-d9}, [%[r0]]! @ load r0\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r0, 8 " "vld1.32 {d10}, [%[r0]] @ load r0\n"
"float\n" /* main loop */
"vld1.32 {d10}, [%[r0]] @ load r0, 2 " "0: @ main loop\n"
"float\n" /* r0 * w0, w1, w2, get out r0*/
/* main loop */ "vld1.32 {d28-d31}, [%[ptr_out1]]@ load or10 or11\n"
"0: @ main loop\n" "vext.32 q8, q3, q4, #1 @ r0, shift left 1\n"
/* r0 * w0, w1, w2, get out r0*/ "vext.32 q9, q4, q5, #1 @ r0, shift left 1\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " "vmla.f32 q12, q3, %e[w0][0] @ w00 * r0\n"
"or11\n" "vmla.f32 q13, q4, %e[w0][0] @ w00 * r0\n"
"vext.32 q8, q3, q4, #1 @ r0, shift " "vext.32 q10, q3, q4, #2 @ r0, shift left 2\n"
"left 1, get 1, 2, 3, 4\n" "vext.32 q11, q4, q5, #2 @ r0, shift left 2\n"
"vext.32 q9, q4, q5, #1 @ r0, shift " "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n"
"left 1, get 5, 6, 7, 8\n" "vmla.f32 q13, q9, %e[w0][1] @ w01 * r0\n"
"vmla.f32 q12, q3, %e[w0][0] @ w00 * r0, " "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8\n"
"0, 1, 2, 3\n" "vmla.f32 q12, q10, %f[w0][0] @ w02 * r0\n"
"vmla.f32 q13, q4, %e[w0][0] @ w00 * r0, " "vmla.f32 q13, q11, %f[w0][0] @ w02 * r0\n"
"4, 5, 6, 7\n" "vld1.32 {d10}, [%[r1]] @ load r1\n"
"vext.32 q10, q3, q4, #2 @ r0, shift " /* r1 * w3, w4, w5, get out r0*/
"left 2, get 2, 3, 4, 5\n" /* r1 * w0, w1, w2, get out r1*/
"vext.32 q11, q4, q5, #2 @ r0, shift " "vmla.f32 q12, q3, %e[w1][0] @ w10 * r1\n"
"left 2, get 6, 7, 8, 9\n" "vmla.f32 q13, q4, %e[w1][0] @ w10 * r1\n"
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " "vext.32 q8, q3, q4, #1 @ r1, shift left 1\n"
"1, 2, 3, 4\n" "vext.32 q9, q4, q5, #1 @ r1, shift left 1\n"
"vmla.f32 q13, q9, %e[w0][1] @ w01 * r0, " "vmla.f32 q14, q3, %e[w0][0] @ w00 * r1\n"
"5, 6, 7, 8\n" "vmla.f32 q15, q4, %e[w0][0] @ w00 * r1\n"
"vld1.32 {d6-d9}, [%[r1]]! @ load r1, 8 " "vext.32 q10, q3, q4, #2 @ r1, shift left 2\n"
"float\n" "vext.32 q11, q4, q5, #2 @ r1, shift left 2\n"
"vmla.f32 q12, q10, %f[w0][0] @ w02 * r0, " "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n"
"2, 3, 4, 5\n" "vmla.f32 q13, q9, %e[w1][1] @ w11 * r1\n"
"vmla.f32 q13, q11, %f[w0][0] @ w02 * r0, " "vmla.f32 q14, q8, %e[w0][1] @ w01 * r1\n"
"6, 7, 8, 9\n" "vmla.f32 q15, q9, %e[w0][1] @ w01 * r1\n"
"vld1.32 {d10}, [%[r1]] @ load r1, 2 " "vld1.32 {d6-d9}, [%[r2]]! @ load r2\n"
"float\n" "vmla.f32 q12, q10, %f[w1][0] @ w12 * r1\n"
"vmla.f32 q13, q11, %f[w1][0] @ w12 * r1\n"
/* r1 * w3, w4, w5, get out r0*/ "vmla.f32 q14, q10, %f[w0][0] @ w02 * r1\n"
/* r1 * w0, w1, w2, get out r1*/ "vmla.f32 q15, q11, %f[w0][0] @ w02 * r1\n"
"vmla.f32 q12, q3, %e[w1][0] @ w10 * r1, " "vld1.32 {d10}, [%[r2]] @ load r2\n"
"0, 1, 2, 3\n" /* r2 * w6, w7, w8, get out r0*/
"vmla.f32 q13, q4, %e[w1][0] @ w10 * r1, " /* r2 * w3, w4, w5, get out r1*/
"4, 5, 6, 7\n" "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n"
"vext.32 q8, q3, q4, #1 @ r1, shift " "vmla.f32 q13, q4, %e[w2][0] @ w20 * r2\n"
"left 1, get 1, 2, 3, 4\n" "vext.32 q8, q3, q4, #1 @ r2, shift left 1\n"
"vext.32 q9, q4, q5, #1 @ r1, shift " "vext.32 q9, q4, q5, #1 @ r2, shift left 1\n"
"left 1, get 5, 6, 7, 8\n" "vmla.f32 q14, q3, %e[w1][0] @ w10 * r2\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r1, " "vmla.f32 q15, q4, %e[w1][0] @ w10 * r2\n"
"0, 1, 2, 3\n" "vext.32 q10, q3, q4, #2 @ r2, shift left 2\n"
"vmla.f32 q15, q4, %e[w0][0] @ w00 * r1, " "vext.32 q11, q4, q5, #2 @ r2, shift left 2\n"
"4, 5, 6, 7\n" "vmla.f32 q12, q8, %e[w2][1] @ w21 * r2\n"
"vext.32 q10, q3, q4, #2 @ r1, shift " "vmla.f32 q13, q9, %e[w2][1] @ w21 * r2\n"
"left 2, get 2, 3, 4, 5\n" "vmla.f32 q14, q8, %e[w1][1] @ w11 * r2\n"
"vext.32 q11, q4, q5, #2 @ r1, shift " "vmla.f32 q15, q9, %e[w1][1] @ w11 * r2\n"
"left 2, get 6, 7, 8, 9\n" "vld1.32 {d6-d9}, [%[r3]]! @ load r3\n"
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, " "vmla.f32 q12, q10, %f[w2][0] @ w22 * r2\n"
"1, 2, 3, 4\n" "vmla.f32 q13, q11, %f[w2][0] @ w22 * r2\n"
"vmla.f32 q13, q9, %e[w1][1] @ w11 * r1, " "vmla.f32 q14, q10, %f[w1][0] @ w12 * r2\n"
"5, 6, 7, 8\n" "vmla.f32 q15, q11, %f[w1][0] @ w12 * r2\n"
"vmla.f32 q14, q8, %e[w0][1] @ w01 * r1, " "vld1.32 {d10}, [%[r3]] @ load r3\n"
"1, 2, 3, 4\n" /* r3 * w6, w7, w8, get out r1*/
"vmla.f32 q15, q9, %e[w0][1] @ w01 * r1, " "vext.32 q8, q3, q4, #1 @ r3, shift left 1\n"
"5, 6, 7, 8\n" "vext.32 q9, q4, q5, #1 @ r3, shift left 1\n"
"vld1.32 {d6-d9}, [%[r2]]! @ load r2, 8 " "vmla.f32 q14, q3, %e[w2][0] @ w20 * r3\n"
"float\n" "vmla.f32 q15, q4, %e[w2][0] @ w20 * r3\n"
"vmla.f32 q12, q10, %f[w1][0] @ w12 * r1, " "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, or01\n"
"2, 3, 4, 5\n" "vext.32 q10, q3, q4, #2 @ r3, shift left 2\n"
"vmla.f32 q13, q11, %f[w1][0] @ w12 * r1, " "vext.32 q11, q4, q5, #2 @ r3, shift left 2\n"
"6, 7, 8, 9\n" "vmla.f32 q14, q8, %e[w2][1] @ w21 * r3\n"
"vmla.f32 q14, q10, %f[w0][0] @ w02 * r1, " "vmla.f32 q15, q9, %e[w2][1] @ w21 * r3\n"
"2, 3, 4, 5\n" "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00,or01\n"
"vmla.f32 q15, q11, %f[w0][0] @ w02 * r1, " "vld1.32 {d6-d9}, [%[r0]]! @ load r3\n"
"6, 7, 8, 9\n" "vmla.f32 q14, q10, %f[w2][0] @ w22 * r3\n"
"vld1.32 {d10}, [%[r2]] @ load r2, 2 " "vmla.f32 q15, q11, %f[w2][0] @ w22 * r3\n"
"float\n" "vld1.32 {d10}, [%[r0]] @ load r0\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, or11\n"
/* r2 * w6, w7, w8, get out r0*/ "subs %[cnt], #1 @ loop count -1\n"
/* r2 * w3, w4, w5, get out r1*/ "bne 0b @ jump to main loop\n"
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " : [cnt] "+r"(cnt),
"0, 1, 2, 3\n" [r0] "+r"(r0),[r1] "+r"(r1),
"vmla.f32 q13, q4, %e[w2][0] @ w20 * r2, " [r2] "+r"(r2),[r3] "+r"(r3),
"4, 5, 6, 7\n" [ptr_out0] "+r"(ptr_out0),
"vext.32 q8, q3, q4, #1 @ r2, shift " [ptr_out1] "+r"(ptr_out1)
"left 1, get 1, 2, 3, 4\n" : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
"vext.32 q9, q4, q5, #1 @ r2, shift " : "cc","memory","q3","q4",
"left 1, get 5, 6, 7, 8\n" "q5","q6","q7","q8","q9","q10",
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r2, " "q11","q12","q13","q14","q15"
"0, 1, 2, 3\n" );
"vmla.f32 q15, q4, %e[w1][0] @ w10 * r2, " // clang-format on
"4, 5, 6, 7\n"
"vext.32 q10, q3, q4, #2 @ r2, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r2, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q12, q8, %e[w2][1] @ w21 * r2, "
"1, 2, 3, 4\n"
"vmla.f32 q13, q9, %e[w2][1] @ w21 * r2, "
"5, 6, 7, 8\n"
"vmla.f32 q14, q8, %e[w1][1] @ w11 * r2, "
"1, 2, 3, 4\n"
"vmla.f32 q15, q9, %e[w1][1] @ w11 * r2, "
"5, 6, 7, 8\n"
"vld1.32 {d6-d9}, [%[r3]]! @ load r3, 8 "
"float\n"
"vmla.f32 q12, q10, %f[w2][0] @ w22 * r2, "
"2, 3, 4, 5\n"
"vmla.f32 q13, q11, %f[w2][0] @ w22 * r2, "
"6, 7, 8, 9\n"
"vmla.f32 q14, q10, %f[w1][0] @ w12 * r2, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w1][0] @ w12 * r2, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r3]] @ load r3, 2 "
"float\n"
/* r3 * w6, w7, w8, get out r1*/
"vext.32 q8, q3, q4, #1 @ r3, shift "
"left 1, get 1, 2, 3, 4\n"
"vext.32 q9, q4, q5, #1 @ r3, shift "
"left 1, get 5, 6, 7, 8\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r3, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q4, %e[w2][0] @ w20 * r3, "
"4, 5, 6, 7\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or00, "
"or01\n"
"vext.32 q10, q3, q4, #2 @ r3, shift "
"left 2, get 2, 3, 4, 5\n"
"vext.32 q11, q4, q5, #2 @ r3, shift "
"left 2, get 6, 7, 8, 9\n"
"vmla.f32 q14, q8, %e[w2][1] @ w21 * r3, "
"0, 1, 2, 3\n"
"vmla.f32 q15, q9, %e[w2][1] @ w21 * r3, "
"4, 5, 6, 7\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, "
"or01\n"
"vld1.32 {d6-d9}, [%[r0]]! @ load r3, 8 "
"float\n"
"vmla.f32 q14, q10, %f[w2][0] @ w22 * r3, "
"2, 3, 4, 5\n"
"vmla.f32 q15, q11, %f[w2][0] @ w22 * r3, "
"6, 7, 8, 9\n"
"vld1.32 {d10}, [%[r0]] @ load r0, 2 "
"float\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or10, "
"or11\n"
"subs %[cnt], #1 @loop count "
"-1\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
r0 -= 8; r0 -= 8;
} }
//! deal with remain ow //! deal with remain ow
......
...@@ -24,6 +24,39 @@ namespace lite { ...@@ -24,6 +24,39 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
const int OUT_C_BLOCK = 4;
const int OUT_H_BLOCK = 2;
const int OUT_W_BLOCK = 4;
size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx) {
auto dim_in = param.x->dims();
auto dim_out = param.output->dims();
const int threads = ctx->threads();
int llc_size = ctx->llc_size() / sizeof(float);
const int pad_w = param.paddings[1];
const int pad_h = param.paddings[0];
int ow = dim_out[3];
int oh = dim_out[2];
int ic = dim_in[1];
const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
const int win_round = wout_round * 2 /*stride_w*/ + 1;
const int hin_r_block = OUT_H_BLOCK * 2 /*stride_h*/ + 1;
int hout_r_block =
(llc_size - 2 * wout_round * ic - ic) /
((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads);
hout_r_block = hout_r_block > oh ? oh : hout_r_block;
hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
return sizeof(float) * (pre_in_size + ctx->threads() * pre_out_size);
}
void conv_3x3s2_direct_fp32(const float* i_data, void conv_3x3s2_direct_fp32(const float* i_data,
float* o_data, float* o_data,
int bs, int bs,
...@@ -44,53 +77,50 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -44,53 +77,50 @@ void conv_3x3s2_direct_fp32(const float* i_data,
int l2_size = ctx->llc_size() / sizeof(float); int l2_size = ctx->llc_size() / sizeof(float);
const int pad_w = param.paddings[1]; const int pad_w = param.paddings[1];
const int pad_h = param.paddings[0]; const int pad_h = param.paddings[0];
const int hout_c_block = 4; const int wout_round = ROUNDUP(ow, OUT_W_BLOCK);
const int hout_r_kernel = 2;
const int wout_block = 4;
const int wout_round = ((ow + wout_block - 1) / wout_block) * wout_block;
const int win_round = wout_round * 2 /*stride_w*/ + 1; const int win_round = wout_round * 2 /*stride_w*/ + 1;
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
//! get h block //! get h block
//! win_round * ic * hin_r_block + wout_round * hout_c_block * hout_r_block //! win_round * ic * hin_r_block + wout_round * OUT_C_BLOCK * hout_r_block
//! * threads = l2_size //! * threads = l2_size
//! win_round = 2 * wout_round + 1 //! win_round = 2 * wout_round + 1
//! hin_r_block = 2 * hout_r_block + 1 //! hin_r_block = 2 * hout_r_block + 1
int hout_r_block = int hout_r_block =
(l2_size - 2 * wout_round * ic - ic) / (l2_size - 2 * wout_round * ic - ic) /
((4 * wout_round + 2) * ic + wout_round * hout_c_block * threads); ((4 * wout_round + 2) * ic + wout_round * OUT_C_BLOCK * threads);
hout_r_block = hout_r_block > oh ? oh : hout_r_block; hout_r_block = hout_r_block > oh ? oh : hout_r_block;
hout_r_block = (hout_r_block / hout_r_kernel) * hout_r_kernel; hout_r_block = (hout_r_block / OUT_H_BLOCK) * OUT_H_BLOCK;
hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; hout_r_block = hout_r_block < OUT_H_BLOCK ? OUT_H_BLOCK : hout_r_block;
const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1; const int hin_r_block = hout_r_block * 2 /*stride_h*/ + 1;
int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = OUT_C_BLOCK * hout_r_block * wout_round;
float* tmp_work_space = ctx->workspace_data<float>(); float* tmp_work_space = ctx->workspace_data<float>();
float ptr_zero[win_round]; // NOLINT float ptr_zero[win_round]; // NOLINT
memset(ptr_zero, 0, sizeof(float) * win_round); memset(ptr_zero, 0, sizeof(float) * win_round);
float ptr_write[wout_round]; // NOLINT float ptr_write[wout_round]; // NOLINT
int in_len = win_round * ic;
int pre_in_size = hin_r_block * in_len;
int pre_out_size = hout_c_block * hout_r_block * wout_round;
//! l2_cache start //! l2_cache start
float* pre_din = tmp_work_space; float* pre_din = tmp_work_space;
int size_in_channel = win * ih; int size_in_channel = win * ih;
int size_out_channel = ow * oh; int size_out_channel = ow * oh;
int w_stride = ic * 9; /*kernel_w * kernel_h*/ int w_stride = ic * 9; /*kernel_w * kernel_h*/
int w_stride_chin = hout_c_block * 9; // kernel_w * kernel_h * int w_stride_chin = OUT_C_BLOCK * 9; // kernel_w * kernel_h *
int ws = -pad_w; int ws = -pad_w;
int we = ws + win_round; int we = ws + win_round;
int w_loop = wout_round / 4; int w_loop = wout_round / 4;
int c_remain = oc - (oc / hout_c_block) * hout_c_block; int c_remain = oc - (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int c_round_down = (oc / hout_c_block) * hout_c_block; int c_round_down = (oc / OUT_C_BLOCK) * OUT_C_BLOCK;
int out_row_stride = hout_c_block * wout_round; int out_row_stride = OUT_C_BLOCK * wout_round;
for (int n = 0; n < bs; ++n) { for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel; const float* din_batch = i_data + n * ic * size_in_channel;
...@@ -114,7 +144,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -114,7 +144,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* cblock_inr4 = cblock_inr3 + in_len; const float* cblock_inr4 = cblock_inr3 + in_len;
#pragma omp parallel for num_threads(threads) #pragma omp parallel for num_threads(threads)
for (int c = 0; c < c_round_down; c += hout_c_block) { for (int c = 0; c < c_round_down; c += OUT_C_BLOCK) {
#ifdef ARM_WITH_OMP #ifdef ARM_WITH_OMP
float* pre_out = float* pre_out =
pre_din + pre_in_size + omp_get_thread_num() * pre_out_size; pre_din + pre_in_size + omp_get_thread_num() * pre_out_size;
...@@ -133,9 +163,9 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -133,9 +163,9 @@ void conv_3x3s2_direct_fp32(const float* i_data,
bias_ptr = bias + c; bias_ptr = bias + c;
} }
fill_packed_biasc4( fill_packed_biasc4(
pre_out, bias_ptr, wout_round * hout_c_block * h_kernel); pre_out, bias_ptr, wout_round * OUT_C_BLOCK * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_c; const float* wc0 = weight_c;
const float* inr0 = block_inr0; const float* inr0 = block_inr0;
...@@ -168,205 +198,133 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -168,205 +198,133 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* r4 = inr4; const float* r4 = inr4;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, "ldp q15, q16, [%[ptr_out0]]\n" /* load outr00, outr01*/
outr01*/ "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/ "ldp q0, q1, [%[r0]], #32\n" /* load input r0*/
"ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ "ldp q4, q5, [%[r2]], #32\n" /* load input r2*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
element*/ "2:\n" /* main loop*/
"ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ /* r0, r2, mul w0, get out r0, r1 */
"ldr d12, [%[r2]] \n" /* load input r2, 9th "ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/
element*/ "ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/
"2: \n" /* main loop*/ "fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/
/* r0, r2, mul w0, get out r0, r1 */ "fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/
"ldp q19, q20, [%[ptr_out1]] \n" /* load outr10, outr11*/ "fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/
"ldp q21, q22, [%[ptr_out1], #32]\n" /* load outr12, outr13*/ "fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/
"fmla v15.4s , %[w0].4s, v0.s[0]\n" /* outr00 = w0 * r0[0]*/ "fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/
"fmla v16.4s , %[w0].4s, v0.s[2]\n" /* outr01 = w0 * r0[2]*/ "fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/
"fmla v17.4s , %[w0].4s, v1.s[0]\n" /* outr02 = w0 * r0[4]*/ "fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/
"fmla v18.4s , %[w0].4s, v1.s[2]\n" /* outr03 = w0 * r0[6]*/ "fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/
"fmla v19.4s , %[w0].4s, v4.s[0]\n" /* outr10 = w0 * r2[0]*/ "ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/
"fmla v20.4s , %[w0].4s, v4.s[2]\n" /* outr11 = w0 * r2[2]*/ /* r2 mul w6, get out r0*/
"fmla v21.4s , %[w0].4s, v5.s[0]\n" /* outr12 = w0 * r2[4]*/ "fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/
"fmla v22.4s , %[w0].4s, v5.s[2]\n" /* outr13 = w0 * r2[6]*/ "fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/
"fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/
"ldp q2, q3, [%[r1]], #32 \n" /* load input r1*/ "fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/
"ldr d11, [%[r1]]\n" /* load input r1, 9th element*/
/* r2 mul w6, get out r0*/ /* r0, r2, mul w1, get out r0, r1 */
"fmla v15.4s , %[w6].4s, v4.s[0]\n" /* outr00 = w6 * r2[0]*/ "fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/
"fmla v16.4s , %[w6].4s, v4.s[2]\n" /* outr01 = w6 * r2[2]*/ "fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/
"fmla v17.4s , %[w6].4s, v5.s[0]\n" /* outr02 = w6 * r2[4]*/ "fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/
"fmla v18.4s , %[w6].4s, v5.s[2]\n" /* outr03 = w6 * r2[6]*/ "fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/
"fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/
"ldr d11, [%[r1]] \n" /* load input r1, 9th "fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/
element*/ "fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/
"fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/
/* r0, r2, mul w1, get out r0, r1 */ "ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/
"fmla v15.4s , %[w1].4s, v0.s[1]\n" /* outr00 = w1 * r0[1]*/ /* r2 mul w7, get out r0 */
"fmla v16.4s , %[w1].4s, v0.s[3]\n" /* outr01 = w1 * r0[3]*/ "fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/
"fmla v17.4s , %[w1].4s, v1.s[1]\n" /* outr02 = w1 * r0[5]*/ "fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/
"fmla v18.4s , %[w1].4s, v1.s[3]\n" /* outr03 = w1 * r0[7]*/ "fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/
"fmla v19.4s , %[w1].4s, v4.s[1]\n" /* outr10 = w1 * r2[1]*/ "fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/
"fmla v20.4s , %[w1].4s, v4.s[3]\n" /* outr11 = w1 * r2[3]*/ "ldr d13, [%[r3]]\n" /* load input r3, 9th element*/
"fmla v21.4s , %[w1].4s, v5.s[1]\n" /* outr12 = w1 * r2[5]*/ /* r0, r2, mul w2, get out r0, r1 */
"fmla v22.4s , %[w1].4s, v5.s[3]\n" /* outr13 = w1 * r2[7]*/ "fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/
"ldp q6, q7, [%[r3]], #32 \n" /* load input r3*/ "fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/
"fmla v18.4s , %[w2].4s, v10.s[0]\n"/* outr03 = w2 * r0[8]*/
/* r2 mul w7, get out r0 */ "fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/
"fmla v15.4s , %[w7].4s, v4.s[1]\n" /* outr00 = w7 * r2[1]*/ "fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/
"fmla v16.4s , %[w7].4s, v4.s[3]\n" /* outr01 = w7 * r2[3]*/ "fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/
"fmla v17.4s , %[w7].4s, v5.s[1]\n" /* outr02 = w7 * r2[5]*/ "fmla v22.4s , %[w2].4s, v12.s[0]\n"/* outr13 = w2 * r2[8]*/
"fmla v18.4s , %[w7].4s, v5.s[3]\n" /* outr03 = w7 * r2[7]*/ "ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/
/* r2, mul w8, get out r0 */
"ldr d13, [%[r3]] \n" /* load input r3, 9th "fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/
element*/ "fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/
"fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/
/* r0, r2, mul w2, get out r0, r1 */ "fmla v18.4s , %[w8].4s, v12.s[0]\n"/* outr03 = w8 * r2[8]*/
"fmla v15.4s , %[w2].4s, v0.s[2]\n" /* outr00 = w2 * r0[2]*/ "ldr d14, [%[r4]]\n" /* load input r4, 9th element*/
"fmla v16.4s , %[w2].4s, v1.s[0]\n" /* outr01 = w2 * r0[4]*/ /* r1, r3, mul w3, get out r0, r1 */
"fmla v17.4s , %[w2].4s, v1.s[2]\n" /* outr02 = w2 * r0[6]*/ "fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/
"fmla v18.4s , %[w2].4s, v10.s[0]\n" /* outr03 = w2 * "fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/
r0[8]*/ "fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/
"fmla v19.4s , %[w2].4s, v4.s[2]\n" /* outr10 = w2 * r2[2]*/ "fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/
"fmla v20.4s , %[w2].4s, v5.s[0]\n" /* outr11 = w2 * r2[4]*/ "fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/
"fmla v21.4s , %[w2].4s, v5.s[2]\n" /* outr12 = w2 * r2[6]*/ "fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/
"fmla v22.4s , %[w2].4s, v12.s[0]\n" /* outr13 = w2 * "fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/
r2[8]*/ "fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/
"ldp q8, q9, [%[r4]], #32 \n" /* load input r4*/ /* r1, r3, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/
/* r2, mul w8, get out r0 */ "fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/
"fmla v15.4s , %[w8].4s, v4.s[2]\n" /* outr00 = w8 * r2[2]*/ "fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/
"fmla v16.4s , %[w8].4s, v5.s[0]\n" /* outr01 = w8 * r2[4]*/ "fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/
"fmla v17.4s , %[w8].4s, v5.s[2]\n" /* outr02 = w8 * r2[6]*/ "fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/
"fmla v18.4s , %[w8].4s, v12.s[0]\n" /* outr03 = w8 * "fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/
r2[8]*/ "fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/
"fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/
"ldr d14, [%[r4]] \n" /* load input r4, 9th "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
element*/ /* r1, r3, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/
/* r1, r3, mul w3, get out r0, r1 */ "fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/
"fmla v15.4s , %[w3].4s, v2.s[0]\n" /* outr00 = w3 * r1[0]*/ "fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/
"fmla v16.4s , %[w3].4s, v2.s[2]\n" /* outr01 = w3 * r1[2]*/ "fmla v18.4s , %[w5].4s, v11.s[0]\n"/* outr03 = w5 * r1[8]*/
"fmla v17.4s , %[w3].4s, v3.s[0]\n" /* outr02 = w3 * r1[4]*/ "ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/
"fmla v18.4s , %[w3].4s, v3.s[2]\n" /* outr03 = w3 * r1[6]*/ "stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/
"fmla v19.4s , %[w3].4s, v6.s[0]\n" /* outr10 = w3 * r3[0]*/ "fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/
"fmla v20.4s , %[w3].4s, v6.s[2]\n" /* outr11 = w3 * r3[2]*/ "fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/
"fmla v21.4s , %[w3].4s, v7.s[0]\n" /* outr12 = w3 * r3[4]*/ "fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/
"fmla v22.4s , %[w3].4s, v7.s[2]\n" /* outr13 = w3 * r3[6]*/ "fmla v22.4s , %[w5].4s, v13.s[0]\n"/* outr13 = w5 * r3[8]*/
"ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"ldp q0, q1, [%[r0]], #32 \n" /* load input r0*/ "stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/
/* r4, mul w6, get out r1 */
/* r1, r3, mul w4, get out r0, r1 */ "fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/
"fmla v15.4s , %[w4].4s, v2.s[1]\n" /* outr00 = w4 * r1[1]*/ "fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/
"fmla v16.4s , %[w4].4s, v2.s[3]\n" /* outr01 = w4 * r1[3]*/ "fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/
"fmla v17.4s , %[w4].4s, v3.s[1]\n" /* outr02 = w4 * r1[5]*/ "fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/
"fmla v18.4s , %[w4].4s, v3.s[3]\n" /* outr03 = w4 * r1[7]*/ "ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
"fmla v19.4s , %[w4].4s, v6.s[1]\n" /* outr10 = w4 * r3[1]*/ /* r4, mul w7, get out r1 */
"fmla v20.4s , %[w4].4s, v6.s[3]\n" /* outr11 = w4 * r3[3]*/ "fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/
"fmla v21.4s , %[w4].4s, v7.s[1]\n" /* outr12 = w4 * r3[5]*/ "fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/
"fmla v22.4s , %[w4].4s, v7.s[3]\n" /* outr13 = w4 * r3[7]*/ "fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/
"fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th "ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
element*/ /* r4, mul w8, get out r1 */
"fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/
/* r1, r3, mul w5, get out r0, r1 */ "fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/
"fmla v15.4s , %[w5].4s, v2.s[2]\n" /* outr00 = w5 * r1[2]*/ "fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/
"fmla v16.4s , %[w5].4s, v3.s[0]\n" /* outr01 = w5 * r1[4]*/ "fmla v22.4s , %[w8].4s, v14.s[0]\n"/* outr13 = w8 * r4[8]*/
"fmla v17.4s , %[w5].4s, v3.s[2]\n" /* outr02 = w5 * r1[6]*/ "subs %w[cnt], %w[cnt], #1\n" /*loop count -1*/
"fmla v18.4s , %[w5].4s, v11.s[0]\n" /* outr03 = w5 * "stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/
r1[8]*/ "stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/
"bne 2b \n" /* jump to main loop*/
"ldp q4, q5, [%[r2]], #32 \n" /* load input r2*/ : [cnt] "+r"(cnt), [r0] "+r"(r0), [r1] "+r"(r1),
"stp q15, q16, [%[ptr_out0]], #32\n" /* save outr00, outr01*/ [r2] "+r"(r2),[r3] "+r"(r3), [r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
"fmla v19.4s , %[w5].4s, v6.s[2]\n" /* outr10 = w5 * r3[2]*/ [ptr_out1] "+r"(ptr_out1)
"fmla v20.4s , %[w5].4s, v7.s[0]\n" /* outr11 = w5 * r3[4]*/ : [w0] "w"(w0),
"fmla v21.4s , %[w5].4s, v7.s[2]\n" /* outr12 = w5 * r3[6]*/ [w1] "w"(w1), [w2] "w"(w2),
"fmla v22.4s , %[w5].4s, v13.s[0]\n" /* outr13 = w5 * [w3] "w"(w3), [w4] "w"(w4),
r3[8]*/ [w5] "w"(w5), [w6] "w"(w6),
[w7] "w"(w7), [w8] "w"(w8)
"ldr d12, [%[r2]] \n" /* load input r2, 9th : "cc","memory","v0","v1","v2","v3","v4",
element*/ "v5","v6","v7","v8","v9","v10","v11","v12","v13",
"stp q17, q18, [%[ptr_out0]], #32\n" /* save outr02, outr03*/ "v14","v15","v16","v17","v18","v19","v20","v21","v22");
// clang-format on
/* r4, mul w6, get out r1 */ wc0 += 9 * OUT_C_BLOCK;
"fmla v19.4s , %[w6].4s, v8.s[0]\n" /* outr10 = w6 * r4[0]*/
"fmla v20.4s , %[w6].4s, v8.s[2]\n" /* outr11 = w6 * r4[2]*/
"fmla v21.4s , %[w6].4s, v9.s[0]\n" /* outr12 = w6 * r4[4]*/
"fmla v22.4s , %[w6].4s, v9.s[2]\n" /* outr13 = w6 * r4[6]*/
"ldp q15, q16, [%[ptr_out0]] \n" /* load outr00, outr01*/
/* r4, mul w7, get out r1 */
"fmla v19.4s , %[w7].4s, v8.s[1]\n" /* outr10 = w7 * r4[1]*/
"fmla v20.4s , %[w7].4s, v8.s[3]\n" /* outr11 = w7 * r4[3]*/
"fmla v21.4s , %[w7].4s, v9.s[1]\n" /* outr12 = w7 * r4[5]*/
"fmla v22.4s , %[w7].4s, v9.s[3]\n" /* outr13 = w7 * r4[7]*/
"ldp q17, q18, [%[ptr_out0], #32]\n" /* load outr02, outr03*/
/* r4, mul w8, get out r1 */
"fmla v19.4s , %[w8].4s, v8.s[2]\n" /* outr10 = w8 * r4[2]*/
"fmla v20.4s , %[w8].4s, v9.s[0]\n" /* outr11 = w8 * r4[4]*/
"fmla v21.4s , %[w8].4s, v9.s[2]\n" /* outr12 = w8 * r4[6]*/
"fmla v22.4s , %[w8].4s, v14.s[0]\n" /* outr13 = w8 *
r4[8]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"stp q19, q20, [%[ptr_out1]], #32\n" /* save outr10, outr11*/
"stp q21, q22, [%[ptr_out1]], #32\n" /* save outr12, outr13*/
"bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
wc0 += 9 * hout_c_block;
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
inr2 += win_round; inr2 += win_round;
...@@ -387,285 +345,142 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -387,285 +345,142 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* r4 = inr4; const float* r4 = inr4;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ " "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"load outr0, w0, w1, c0~c3\n" "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load " /* load weights */
"outr0, w2, w3, c0~c3\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
/* load weights */ /* load r0, r2 */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, " "vld1.32 {d0-d3}, [%[r0]]! @ load r0\n"
"w1, to q5, q6\n" "vld1.32 {d8}, [%[r0]] @ load r0\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, " "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 -32\n"
"to q7\n" /* main loop */
"0: @ main loop\n"
/* load r0, r2 */ /* mul r0, with w0, w1, w2 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, " "vld1.32 {d24-d27}, [%[ptr_out1]]! @ load outr1\n"
"8 float\n" "vmla.f32 q8, q5, d0[0] @ w0 * inr00\n"
"vld1.32 {d8}, [%[r0]] @ load r0, " "vld1.32 {d28-d31}, [%[ptr_out1]] @ load outr1\n"
"9th float\n" "vmla.f32 q9, q5, d1[0] @ w0 * inr02\n"
"vmla.f32 q10, q5, d2[0] @ w0 * inr04\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 " "vmla.f32 q11, q5, d3[0] @ w0 * inr06\n"
"- 32, to start address\n" "vld1.32 {d4-d7}, [%[r2]]! @ load r2\n"
"vmla.f32 q8, q6, d0[1] @ w1 * inr01\n"
/* main loop */ "vmla.f32 q9, q6, d1[1] @ w1 * inr03\n"
"0: @ main " "vmla.f32 q10, q6, d2[1] @ w1 * inr05\n"
"loop\n" "vmla.f32 q11, q6, d3[1] @ w1 * inr07\n"
/* mul r0, with w0, w1, w2 */ "vld1.32 {d9}, [%[r2]] @ load r2, 9th float\n"
"vld1.32 {d24-d27}, [%[ptr_out1]]! @ load " "vmla.f32 q8, q7, d1[0] @ w2 * inr02\n"
"outr1, w0, w1, c0~c3\n" "vmla.f32 q9, q7, d2[0] @ w2 * inr04\n"
"vmla.f32 q8, q5, d0[0] @ w0 * " "vmla.f32 q10, q7, d3[0] @ w2 * inr06\n"
"inr00\n" "vmla.f32 q11, q7, d8[0] @ w2 * inr08\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load " "sub %[r2], %[r2], #32 @ r2 - 32\n"
"outr1, w2, w3, c0~c3\n" /* mul r2, with w0, w1, w2 */
"vmla.f32 q9, q5, d1[0] @ w0 * " "vld1.32 {d0-d3}, [%[r1]]! @ load r1\n"
"inr02\n" "vmla.f32 q12, q5, d4[0] @ w0 * inr20\n"
"vmla.f32 q10, q5, d2[0] @ w0 * " "vmla.f32 q13, q5, d5[0] @ w0 * inr22\n"
"inr04\n" "vmla.f32 q14, q5, d6[0] @ w0 * inr24\n"
"vmla.f32 q11, q5, d3[0] @ w0 * " "vmla.f32 q15, q5, d7[0] @ w0 * inr26\n"
"inr06\n" "vld1.32 {d8}, [%[r1]] @ load r1, 9th float\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, " "vmla.f32 q12, q6, d4[1] @ w1 * inr21\n"
"8 float\n" "vmla.f32 q13, q6, d5[1] @ w1 * inr23\n"
"vmla.f32 q8, q6, d0[1] @ w1 * " "vmla.f32 q14, q6, d6[1] @ w1 * inr25\n"
"inr01\n" "vmla.f32 q15, q6, d7[1] @ w1 * inr27\n"
"vmla.f32 q9, q6, d1[1] @ w1 * " "vld1.32 {d10-d13}, [%[wc0]]! @ load w3, w4, to q5, q6\n"
"inr03\n" "vmla.f32 q12, q7, d5[0] @ w2 * inr22\n"
"vmla.f32 q10, q6, d2[1] @ w1 * " "vmla.f32 q13, q7, d6[0] @ w2 * inr24\n"
"inr05\n" "vmla.f32 q14, q7, d7[0] @ w2 * inr26\n"
"vmla.f32 q11, q6, d3[1] @ w1 * " "vmla.f32 q15, q7, d9[0] @ w2 * inr28\n"
"inr07\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w5, to q7\n"
"vld1.32 {d9}, [%[r2]] @ load r2, " /* mul r1, with w3, w4, w5 */
"9th float\n" "vmla.f32 q8, q5, d0[0] @ w3 * inr10\n"
"vmla.f32 q8, q7, d1[0] @ w2 * " "vmla.f32 q9, q5, d1[0] @ w3 * inr12\n"
"inr02\n" "vmla.f32 q10, q5, d2[0] @ w3 * inr14\n"
"vmla.f32 q9, q7, d2[0] @ w2 * " "vmla.f32 q11, q5, d3[0] @ w3 * inr16\n"
"inr04\n" "vld1.32 {d4-d7}, [%[r3]]! @ load r3, 8 float\n"
"vmla.f32 q10, q7, d3[0] @ w2 * " "vmla.f32 q8, q6, d0[1] @ w4 * inr11\n"
"inr06\n" "vmla.f32 q9, q6, d1[1] @ w4 * inr13\n"
"vmla.f32 q11, q7, d8[0] @ w2 * " "vmla.f32 q10, q6, d2[1] @ w4 * inr15\n"
"inr08\n" "vmla.f32 q11, q6, d3[1] @ w4 * inr17\n"
"vld1.32 {d9}, [%[r3]] @ load r3, 9th float\n"
"sub %[r2], %[r2], #32 @ r2 - 32, " "vmla.f32 q8, q7, d1[0] @ w5 * inr12\n"
"load r2 twice\n" "vmla.f32 q9, q7, d2[0] @ w5 * inr14\n"
"vmla.f32 q10, q7, d3[0] @ w5 * inr16\n"
/* mul r2, with w0, w1, w2 */ "vmla.f32 q11, q7, d8[0] @ w5 * inr18\n"
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, " "sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 - 32\n"
"8 float\n" /* mul r3, with w3, w4, w5 */
"vmla.f32 q12, q5, d4[0] @ w0 * " "vld1.32 {d0-d3}, [%[r2]]! @ load r2\n"
"inr20\n" "vmla.f32 q12, q5, d4[0] @ w3 * inr30\n"
"vmla.f32 q13, q5, d5[0] @ w0 * " "vmla.f32 q13, q5, d5[0] @ w3 * inr32\n"
"inr22\n" "vmla.f32 q14, q5, d6[0] @ w3 * inr34\n"
"vmla.f32 q14, q5, d6[0] @ w0 * " "vmla.f32 q15, q5, d7[0] @ w3 * inr36\n"
"inr24\n" "vld1.32 {d8}, [%[r2]] @ load r2, 9th float\n"
"vmla.f32 q15, q5, d7[0] @ w0 * " "vmla.f32 q12, q6, d4[1] @ w4 * inr31\n"
"inr26\n" "vmla.f32 q13, q6, d5[1] @ w4 * inr33\n"
"vld1.32 {d8}, [%[r1]] @ load r1, " "vmla.f32 q14, q6, d6[1] @ w4 * inr35\n"
"9th float\n" "vmla.f32 q15, q6, d7[1] @ w4 * inr37\n"
"vmla.f32 q12, q6, d4[1] @ w1 * " "vld1.32 {d10-d13}, [%[wc0]]! @ load w6, w7\n"
"inr21\n" "vmla.f32 q12, q7, d5[0] @ w5 * inr32\n"
"vmla.f32 q13, q6, d5[1] @ w1 * " "vmla.f32 q13, q7, d6[0] @ w5 * inr34\n"
"inr23\n" "vmla.f32 q14, q7, d7[0] @ w5 * inr36\n"
"vmla.f32 q14, q6, d6[1] @ w1 * " "vmla.f32 q15, q7, d9[0] @ w5 * inr38\n"
"inr25\n" "vld1.32 {d14-d15}, [%[wc0]]! @ load w8\n"
"vmla.f32 q15, q6, d7[1] @ w1 * " /* mul r2, with w6, w7, w8 */
"inr27\n" "vmla.f32 q8, q5, d0[0] @ w6 * inr20\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w3, " "vmla.f32 q9, q5, d1[0] @ w6 * inr22\n"
"w4, to q5, q6\n" "vmla.f32 q10, q5, d2[0] @ w6 * inr24\n"
"vmla.f32 q12, q7, d5[0] @ w2 * " "vmla.f32 q11, q5, d3[0] @ w6 * inr26\n"
"inr22\n" "vld1.32 {d4-d7}, [%[r4]]! @ load r4\n"
"vmla.f32 q13, q7, d6[0] @ w2 * " "vmla.f32 q8, q6, d0[1] @ w7 * inr21\n"
"inr24\n" "vmla.f32 q9, q6, d1[1] @ w7 * inr23\n"
"vmla.f32 q14, q7, d7[0] @ w2 * " "vmla.f32 q10, q6, d2[1] @ w7 * inr25\n"
"inr26\n" "vmla.f32 q11, q6, d3[1] @ w7 * inr27\n"
"vmla.f32 q15, q7, d9[0] @ w2 * " "vld1.32 {d9}, [%[r4]] @ load r4, 9th float\n"
"inr28\n" "vmla.f32 q8, q7, d1[0] @ w8 * inr22\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w5, " "vmla.f32 q9, q7, d2[0] @ w8 * inr24\n"
"to q7\n" "vmla.f32 q10, q7, d3[0] @ w8 * inr26\n"
"vmla.f32 q11, q7, d8[0] @ w8 * inr28\n"
/* mul r1, with w3, w4, w5 */ "sub %[wc0], %[wc0], #144 @ wc0 - 144\n"
"vmla.f32 q8, q5, d0[0] @ w3 * " /* mul r4, with w6, w7, w8 */
"inr10\n" "vld1.32 {d0-d3}, [%[r0]]! @ load r0\n"
"vmla.f32 q9, q5, d1[0] @ w3 * " "vmla.f32 q12, q5, d4[0] @ w3 * inr40\n"
"inr12\n" "vst1.32 {d16-d19}, [%[ptr_out0]]! @ save r00, r01\n"
"vmla.f32 q10, q5, d2[0] @ w3 * " "vmla.f32 q13, q5, d5[0] @ w3 * inr42\n"
"inr14\n" "vst1.32 {d20-d23}, [%[ptr_out0]]! @ save r02, r03\n"
"vmla.f32 q11, q5, d3[0] @ w3 * " "vmla.f32 q14, q5, d6[0] @ w3 * inr44\n"
"inr16\n" "vmla.f32 q15, q5, d7[0] @ w3 * inr46\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, " "vld1.32 {d8}, [%[r0]] @ load r0, 9th float\n"
"8 float\n" "vmla.f32 q12, q6, d4[1] @ w4 * inr41\n"
"vmla.f32 q8, q6, d0[1] @ w4 * " "vmla.f32 q13, q6, d5[1] @ w4 * inr43\n"
"inr11\n" "vmla.f32 q14, q6, d6[1] @ w4 * inr45\n"
"vmla.f32 q9, q6, d1[1] @ w4 * " "vmla.f32 q15, q6, d7[1] @ w4 * inr47\n"
"inr13\n" "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1\n"
"vmla.f32 q10, q6, d2[1] @ w4 * " "vmla.f32 q12, q7, d5[0] @ w5 * inr42\n"
"inr15\n" "vmla.f32 q13, q7, d6[0] @ w5 * inr44\n"
"vmla.f32 q11, q6, d3[1] @ w4 * " "vmla.f32 q14, q7, d7[0] @ w5 * inr46\n"
"inr17\n" "vmla.f32 q15, q7, d9[0] @ w5 * inr48\n"
"vld1.32 {d9}, [%[r3]] @ load r3, " "vld1.32 {d14-d15}, [%[wc0]]! @ load w2\n"
"9th float\n" "vst1.32 {d24-d27}, [%[ptr_out1]]! @ save r10, r11\n"
"vmla.f32 q8, q7, d1[0] @ w5 * " "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save r12, r13\n"
"inr12\n" "vld1.32 {d16-d19}, [%[ptr_out0]]! @ load outr0\n"
"vmla.f32 q9, q7, d2[0] @ w5 * " "vld1.32 {d20-d23}, [%[ptr_out0]] @ load outr0\n"
"inr14\n" "sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 - 32\n"
"vmla.f32 q10, q7, d3[0] @ w5 * " "subs %[cnt], #1 @ loop count--\n"
"inr16\n" "bne 0b @ jump to main loop\n"
"vmla.f32 q11, q7, d8[0] @ w5 * " : [cnt] "+r"(cnt),
"inr18\n" [r0] "+r"(r0),[r1] "+r"(r1),
[r2] "+r"(r2),[r3] "+r"(r3),
"sub %[ptr_out1], %[ptr_out1], #32 @ ptr_out1 " [r4] "+r"(r4),
"- 32, to start address\n" [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1),
/* mul r3, with w3, w4, w5 */ [wc0] "+r"(wc0)
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, " :
"8 float\n" : "cc","memory","q0","q1","q2","q3","q4",
"vmla.f32 q12, q5, d4[0] @ w3 * " "q5","q6","q7","q8","q9","q10",
"inr30\n" "q11","q12","q13","q14","q15"
"vmla.f32 q13, q5, d5[0] @ w3 * " );
"inr32\n" // clang-format on
"vmla.f32 q14, q5, d6[0] @ w3 * "
"inr34\n"
"vmla.f32 q15, q5, d7[0] @ w3 * "
"inr36\n"
"vld1.32 {d8}, [%[r2]] @ load r2, "
"9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * "
"inr31\n"
"vmla.f32 q13, q6, d5[1] @ w4 * "
"inr33\n"
"vmla.f32 q14, q6, d6[1] @ w4 * "
"inr35\n"
"vmla.f32 q15, q6, d7[1] @ w4 * "
"inr37\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w6, "
"w7, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w5 * "
"inr32\n"
"vmla.f32 q13, q7, d6[0] @ w5 * "
"inr34\n"
"vmla.f32 q14, q7, d7[0] @ w5 * "
"inr36\n"
"vmla.f32 q15, q7, d9[0] @ w5 * "
"inr38\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w8, "
"to q7\n"
/* mul r2, with w6, w7, w8 */
"vmla.f32 q8, q5, d0[0] @ w6 * "
"inr20\n"
"vmla.f32 q9, q5, d1[0] @ w6 * "
"inr22\n"
"vmla.f32 q10, q5, d2[0] @ w6 * "
"inr24\n"
"vmla.f32 q11, q5, d3[0] @ w6 * "
"inr26\n"
"vld1.32 {d4-d7}, [%[r4]]! @ load r4, "
"8 float\n"
"vmla.f32 q8, q6, d0[1] @ w7 * "
"inr21\n"
"vmla.f32 q9, q6, d1[1] @ w7 * "
"inr23\n"
"vmla.f32 q10, q6, d2[1] @ w7 * "
"inr25\n"
"vmla.f32 q11, q6, d3[1] @ w7 * "
"inr27\n"
"vld1.32 {d9}, [%[r4]] @ load r4, "
"9th float\n"
"vmla.f32 q8, q7, d1[0] @ w8 * "
"inr22\n"
"vmla.f32 q9, q7, d2[0] @ w8 * "
"inr24\n"
"vmla.f32 q10, q7, d3[0] @ w8 * "
"inr26\n"
"vmla.f32 q11, q7, d8[0] @ w8 * "
"inr28\n"
"sub %[wc0], %[wc0], #144 @ wc0 - "
"144 to start address\n"
/* mul r4, with w6, w7, w8 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, "
"8 float\n"
"vmla.f32 q12, q5, d4[0] @ w3 * "
"inr40\n"
"vst1.32 {d16-d19}, [%[ptr_out0]]! @ save "
"r00, r01, c0~c3\n"
"vmla.f32 q13, q5, d5[0] @ w3 * "
"inr42\n"
"vst1.32 {d20-d23}, [%[ptr_out0]]! @ save "
"r02, r03, c0~c3\n"
"vmla.f32 q14, q5, d6[0] @ w3 * "
"inr44\n"
"vmla.f32 q15, q5, d7[0] @ w3 * "
"inr46\n"
"vld1.32 {d8}, [%[r0]] @ load "
"r0, 9th float\n"
"vmla.f32 q12, q6, d4[1] @ w4 * "
"inr41\n"
"vmla.f32 q13, q6, d5[1] @ w4 * "
"inr43\n"
"vmla.f32 q14, q6, d6[1] @ w4 * "
"inr45\n"
"vmla.f32 q15, q6, d7[1] @ w4 * "
"inr47\n"
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, "
"w1, to q5, q6\n"
"vmla.f32 q12, q7, d5[0] @ w5 * "
"inr42\n"
"vmla.f32 q13, q7, d6[0] @ w5 * "
"inr44\n"
"vmla.f32 q14, q7, d7[0] @ w5 * "
"inr46\n"
"vmla.f32 q15, q7, d9[0] @ w5 * "
"inr48\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, "
"to q7\n"
"vst1.32 {d24-d27}, [%[ptr_out1]]! @ save "
"r10, r11, c0~c3\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save "
"r12, r13, c0~c3\n"
"vld1.32 {d16-d19}, [%[ptr_out0]]! @ load "
"outr0, w0, w1, c0~c3\n"
"vld1.32 {d20-d23}, [%[ptr_out0]] @ load "
"outr0, w2, w3, c0~c3\n"
"sub %[ptr_out0], %[ptr_out0], #32 @ ptr_out0 "
"- 32, to start address\n"
"subs %[cnt], #1 @ loop "
"count--\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1),
[wc0] "+r"(wc0)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
...@@ -684,7 +499,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -684,7 +499,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
write_to_output_c4_fp32(pre_out, write_to_output_c4_fp32(pre_out,
dout_batch, dout_batch,
c, c,
c + hout_c_block, c + OUT_C_BLOCK,
h, h,
h + h_kernel, h + h_kernel,
0, 0,
...@@ -721,7 +536,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -721,7 +536,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
} }
fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel); fill_bias(pre_out, bias_ptr, 1, wout_round * h_kernel);
for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { for (int hk = 0; hk < h_kernel; hk += OUT_H_BLOCK) {
const float* wc0 = weight_c; const float* wc0 = weight_c;
const float* inr0 = block_inr0; const float* inr0 = block_inr0;
...@@ -755,158 +570,80 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -755,158 +570,80 @@ void conv_3x3s2_direct_fp32(const float* i_data,
const float* r4 = inr4; const float* r4 = inr4;
int cnt = w_loop; int cnt = w_loop;
// clang-format off
asm volatile( asm volatile(
"ldr q21, [%[ptr_out0]] \n" /* load outr00, "ldr q21, [%[ptr_out0]]\n" /* load outr00-outr03*/
outr01, "ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/
outr02, "ldr d10, [%[r0]]\n"/* load input r0, 9th element*/
outr03*/ "ld2 {v4.4s, v5.4s}, [%[r2]], #32\n" /* load input r2*/
"ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/ "2:\n" /* main loop*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th
element*/
"2: \n" /* main loop*/
/* r0, r2, mul w0, get out r0, r1 */ /* r0, r2, mul w0, get out r0, r1 */
"ldr q22, [%[ptr_out1]] \n" /* load outr10, outr11, "ldr q22, [%[ptr_out1]]\n" /* load outr10 - outr13*/
outr12, outr13*/ "fmla v21.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0*/
"fmla v22.4s , %[w0].4s, v4.4s\n" /* outr1 = w0 * r2*/
"fmla v21.4s , %[w0].4s, v0.4s \n" /* outr0 = w0 * r0[0, 2, "ld2 {v2.4s, v3.4s}, [%[r1]], #32\n" /* load input r1*/
4, 6]*/
"fmla v22.4s , %[w0].4s, v4.4s \n" /* outr1 = w0 * r2[0, 2,
4, 6]*/
"ld2 {v2.4s, v3.4s}, [%[r1]], #32 \n" /* load input r1*/
/* r2 mul w6, get out r0*/ /* r2 mul w6, get out r0*/
"fmla v21.4s , %[w6].4s, v4.4s \n" /* outr0 = w6 * r2[0, 2, "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr0 = w6 * r2*/
4, 6]*/ "ldr d11, [%[r1]]\n" /* load input r1, 9th element*/
"ldr d11, [%[r1]] \n" /* load input r1, 9th
element*/
/* shift left 1 */ /* shift left 1 */
"ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/ "ext v15.16b, v0.16b, v10.16b, #4\n" /* shift left r0 1*/
"ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/ "ext v16.16b, v4.16b, v12.16b, #4\n" /* shift left r2 1*/
/* r0, r2, mul w1, get out r0, r1 */ /* r0, r2, mul w1, get out r0, r1 */
"fmla v21.4s , %[w1].4s, v1.4s \n" /* outr0 = w1 * r0[1, 3, "fmla v21.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0*/
5, 7]*/ "fmla v22.4s , %[w1].4s, v5.4s\n" /* outr1 = w1 * r2*/
"fmla v22.4s , %[w1].4s, v5.4s \n" /* outr1 = w1 * r2[1, 3, "ld2 {v6.4s, v7.4s}, [%[r3]], #32\n" /* load input r3*/
5, 7]*/
"ld2 {v6.4s, v7.4s}, [%[r3]], #32 \n" /* load input r3*/
/* r2 mul w7, get out r0 */ /* r2 mul w7, get out r0 */
"fmla v21.4s , %[w7].4s, v5.4s \n" /* outr00 = w7 * r2[1, "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr00 = w7 * r2*/
3, 5, 7]*/ "ldr d13, [%[r3]]\n" /* load input r3, 9th element*/
"ldr d13, [%[r3]] \n" /* load input r3, 9th
element*/
/* r0, r2, mul w2, get out r0, r1 */ /* r0, r2, mul w2, get out r0, r1 */
"fmla v21.4s , %[w2].4s, v15.4s \n" /* outr0 = w2 * r0[2, 4, "fmla v21.4s , %[w2].4s, v15.4s\n" /* outr0 = w2 * r0*/
6, 8]*/ "fmla v22.4s , %[w2].4s, v16.4s\n" /* outr1 = w2 * r2*/
"fmla v22.4s , %[w2].4s, v16.4s \n" /* outr1 = w2 * r2[2, 4, "ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/
6, 8]*/
"ld2 {v8.4s, v9.4s}, [%[r4]], #32 \n" /* load input r4*/
/* r2, mul w8, get out r0 */ /* r2, mul w8, get out r0 */
"fmla v21.4s , %[w8].4s, v16.4s \n" /* outr00 = w8 * r2[2, "fmla v21.4s , %[w8].4s, v16.4s\n" /* outr00 = w8 * r2*/
4, 6, 8]*/ "ldr d14, [%[r4]]\n" /* load input r4, 9th element*/
"ldr d14, [%[r4]] \n" /* load input r4, 9th
element*/
/* r1, r3, mul w3, get out r0, r1 */ /* r1, r3, mul w3, get out r0, r1 */
"fmla v21.4s , %[w3].4s, v2.4s \n" /* outr0 = w3 * r1[0, 2, "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr0 = w3 * r1*/
4, 6]*/ "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr1 = w3 * r3*/
"fmla v22.4s , %[w3].4s, v6.4s \n" /* outr1 = w3 * r3[0, 2,
4, 6]*/
/* shift left 1 */ /* shift left 1 */
"ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/ "ext v15.16b, v2.16b, v11.16b, #4\n" /* shift left r1 1*/
"ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/ "ext v16.16b, v6.16b, v13.16b, #4\n" /* shift left r3 1*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32\n" /* load input r0*/
"ld2 {v0.4s, v1.4s}, [%[r0]], #32 \n" /* load input r0*/
/* r1, r3, mul w4, get out r0, r1 */ /* r1, r3, mul w4, get out r0, r1 */
"fmla v21.4s , %[w4].4s, v3.4s \n" /* outr0 = w4 * r1[1, 3, "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr0 = w4 * r1*/
5, 7]*/ "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr1 = w4 * r3*/
"fmla v22.4s , %[w4].4s, v7.4s \n" /* outr1 = w4 * r3[1, 3, "ldr d10, [%[r0]]\n" /* load input r0, 9th element*/
5, 7]*/
"ldr d10, [%[r0]] \n" /* load input r0, 9th
element*/
/* r1, r3, mul w5, get out r0, r1 */ /* r1, r3, mul w5, get out r0, r1 */
"fmla v21.4s , %[w5].4s, v15.4s \n" /* outr0 = w5 * r1[2]*/ "fmla v21.4s , %[w5].4s, v15.4s\n" /* outr0 = w5 * r1[2]*/
"fmla v22.4s , %[w5].4s, v16.4s \n" /* outr1 = w5 * r1[4]*/ "fmla v22.4s , %[w5].4s, v16.4s\n" /* outr1 = w5 * r1[4]*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/
"ld2 {v4.4s, v5.4s}, [%[r2]], #32 \n" /* load input r2*/ "ldr d12, [%[r2]]\n" /* load input r2, 9th element*/
"ldr d12, [%[r2]] \n" /* load input r2, 9th "str q21, [%[ptr_out0]], #16\n" /* save outr00, outr01*/
element*/
"str q21, [%[ptr_out0]], #16 \n" /* save outr00, outr01*/
/* r4, mul w6, get out r1 */ /* r4, mul w6, get out r1 */
"fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4[0, 2, "fmla v22.4s , %[w6].4s, v8.4s \n" /* outr1 = w6 * r4*/
4, 6]*/
"ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/ "ext v15.16b, v8.16b, v14.16b, #4\n" /* shift left r1 1*/
"ldr q21, [%[ptr_out0]] \n" /* load outr0*/ "ldr q21, [%[ptr_out0]] \n" /* load outr0*/
/* r4, mul w7, get out r1 */ /* r4, mul w7, get out r1 */
"fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4[1, 3, "fmla v22.4s , %[w7].4s, v9.4s \n" /* outr1 = w7 * r4*/
5, 7]*/
/* r4, mul w8, get out r1 */ /* r4, mul w8, get out r1 */
"fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4[2, 4, "fmla v22.4s , %[w8].4s, v15.4s \n" /* outr1 = w8 * r4*/
6, 8]*/
"subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/ "subs %w[cnt], %w[cnt], #1 \n" /*loop count -1*/
"str q22, [%[ptr_out1]], #16 \n" /* save outr1*/ "str q22, [%[ptr_out1]], #16 \n" /* save outr1*/
"bne 2b \n" /* jump to main loop*/ "bne 2b \n" /* jump to main loop*/
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[r0] "+r"(r0), [r0] "+r"(r0),[r1] "+r"(r1),
[r1] "+r"(r1), [r2] "+r"(r2),[r3] "+r"(r3),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4), [r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0), [ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1) [ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), : [w0] "w"(w0),[w1] "w"(w1),[w2] "w"(w2),
[w1] "w"(w1), [w3] "w"(w3),[w4] "w"(w4),[w5] "w"(w5),
[w2] "w"(w2), [w6] "w"(w6),[w7] "w"(w7),[w8] "w"(w8)
[w3] "w"(w3), : "cc","memory","v0","v1","v2","v3",
[w4] "w"(w4), "v4","v5","v6","v7","v8","v9","v10","v11",
[w5] "w"(w5), "v12","v13","v14","v15","v16","v21","v22");
[w6] "w"(w6), // clang-format on
[w7] "w"(w7),
[w8] "w"(w8)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v21",
"v22");
wc0 += 36; wc0 += 36;
inr0 += win_round; inr0 += win_round;
inr1 += win_round; inr1 += win_round;
...@@ -944,184 +681,92 @@ void conv_3x3s2_direct_fp32(const float* i_data, ...@@ -944,184 +681,92 @@ void conv_3x3s2_direct_fp32(const float* i_data,
int cnt = w_loop / 2; int cnt = w_loop / 2;
if (cnt > 0) { if (cnt > 0) {
// clang-format off
asm volatile( asm volatile(
/* main loop */ /* main loop */
"0: @ " "0: @ main loop\n"
"main loop\n" "vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, or01\n"
"vld1.32 {d24-d27}, [%[ptr_out0]] @ load or00, " "vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, or11\n"
"or01\n" "vld2.32 {d6-d9}, [%[r2]]! @ load r2\n"
"vld1.32 {d28-d31}, [%[ptr_out1]] @ load or10, " "vld2.32 {d10-d13}, [%[r2]]! @ load r2\n"
"or11\n" "vld1.32 {d22}, [%[r2]] @ load 16th float\n"
"vld2.32 {d6-d9}, [%[r2]]! @ load r2, 8 " /* r2 * w2, r2 * w0, get or0, or1 */
"float, interleave\n" "vmla.f32 q12, q4, %e[w2][1] @ w21 * r2\n"
"vld2.32 {d10-d13}, [%[r2]]! @ load r2, 8 " "vmla.f32 q13, q6, %e[w2][1] @ w21 * r2\n "
"float, interleave\n" "vld2.32 {d14-d17}, [%[r0]]! @ load r0\n"
"vld1.32 {d22}, [%[r2]] @ load 16th " "vmla.f32 q14, q4, %e[w0][1] @ w01 * r2\n"
"float\n" "vmla.f32 q15, q6, %e[w0][1] @ w01 * r2\n"
"vext.32 q4, q3, q5, #1 @ r2, shift left 1\n"
/* r2 * w2, r2 * w0, get or0, or1 */ "vext.32 q6, q5, q11, #1 @ r2, shift left 1\n"
"vmla.f32 q12, q4, %e[w2][1] @ w21 * r2, " "vmla.f32 q12, q3, %e[w2][0] @ w20 * r2\n"
"1, 3, 5, 7\n" "vmla.f32 q13, q5, %e[w2][0] @ w20 * r2\n"
"vmla.f32 q13, q6, %e[w2][1] @ w21 * r2, " "vld2.32 {d18-d21}, [%[r0]]! @ load r0\n"
"9, 11, 13, 15\n" "vmla.f32 q14, q3, %e[w0][0] @ w00 * r2\n"
"vld2.32 {d14-d17}, [%[r0]]! @ load r0, 8 " "vmla.f32 q15, q5, %e[w0][0] @ w00 * r2\n"
"float, interleave\n" "vld1.32 {d22}, [%[r0]] @ load 16th float\n"
"vmla.f32 q14, q4, %e[w0][1] @ w01 * r2, " "vmla.f32 q12, q4, %f[w2][0] @ w22 * r2\n"
"1, 3, 5, 7\n" "vmla.f32 q14, q4, %f[w0][0] @ w02 * r2\n"
"vmla.f32 q15, q6, %e[w0][1] @ w01 * r2, " "vld2.32 {d6-d9}, [%[r3]]! @ load r3\n"
"9, 11, 13, 15\n" "vmla.f32 q13, q6, %f[w2][0] @ w22 * r2\n"
"vmla.f32 q15, q6, %f[w0][0] @ w02 * r2\n"
"vext.32 q4, q3, q5, #1 @ r2, shift " "vld2.32 {d10-d13}, [%[r3]]! @ load r3\n"
"left 1, get 2, 4, 6, 8\n" /* r0 * w0, get or0, r3 * w1, get or1*/
"vext.32 q6, q5, q11, #1 @ r2, shift " "vmla.f32 q12, q8, %e[w0][1] @ w01 * r0\n"
"left 1, get 10, 12, 14, 16\n" "vmla.f32 q13, q10, %e[w0][1] @ w01 * r0\n"
"vext.32 q8, q7, q9, #1 @ r0, shift left 1\n"
"vmla.f32 q12, q3, %e[w2][0] @ w20 * r2, " "vext.32 q10, q9, q11, #1 @ r0, shift left 1\n"
"0, 2, 4, 6\n" "vld1.32 {d22}, [%[r3]] @ load 16th float\n"
"vmla.f32 q13, q5, %e[w2][0] @ w20 * r2, " "vmla.f32 q14, q4, %e[w1][1] @ w11 * r3\n"
"8, 10, 12, 14\n" "vmla.f32 q15, q6, %e[w1][1] @ w11 * r3\n"
"vld2.32 {d18-d21}, [%[r0]]! @ load r0, 8 " "vmla.f32 q12, q7, %e[w0][0] @ w00 * r0\n"
"float, interleave\n" "vmla.f32 q13, q9, %e[w0][0] @ w00 * r0\n"
"vmla.f32 q14, q3, %e[w0][0] @ w00 * r2, " "vext.32 q4, q3, q5, #1 @ r3, shift left 1\n"
"0, 2, 4, 6\n" "vext.32 q6, q5, q11, #1 @ r3, shift left 1\n"
"vmla.f32 q15, q5, %e[w0][0] @ w00 * r2, " "vmla.f32 q14, q3, %e[w1][0] @ w10 * r3\n"
"8, 10, 12, 14\n" "vmla.f32 q15, q5, %e[w1][0] @ w10 * r3\n"
"vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, "
"vld1.32 {d22}, [%[r0]] @ load 16th " "2, 4, 6, 8\n"
"float\n" "vld2.32 {d14-d17}, [%[r1]]! @ load r1\n"
"vmla.f32 q13, q10,%f[w0][0] @ w02 * r0\n"
"vmla.f32 q12, q4, %f[w2][0] @ w22 * r2, " "vld2.32 {d18-d21}, [%[r1]]! @ load r1\n"
"2, 4, 6, 8\n" "vmla.f32 q14, q4, %f[w1][0] @ w12 * r3\n"
"vmla.f32 q14, q4, %f[w0][0] @ w02 * r2, " "vld2.32 {d6-d9}, [%[r4]]! @ load r4\n"
"2, 4, 6, 8\n" "vmla.f32 q15, q6, %f[w1][0] @ w12 * r3\n"
"vld2.32 {d6-d9}, [%[r3]]! @ load r3, 8 " "vld2.32 {d10-d13}, [%[r4]]! @ load r4\n"
"float, interleave\n" "vld1.32 {d22}, [%[r1]] @ load 16th float\n"
"vmla.f32 q13, q6, %f[w2][0] @ w22 * r2, " /* r1 * w1, get or0, r4 * w2, get or1 */
"10, 12, 14, 16\n" "vmla.f32 q12, q8, %e[w1][1] @ w11 * r1\n"
"vmla.f32 q15, q6, %f[w0][0] @ w02 * r2, " "vmla.f32 q13, q10, %e[w1][1] @ w11 * r1\n"
"10, 12, 14, 16\n" "vext.32 q8, q7, q9, #1 @ r1, shift left 1\n"
"vld2.32 {d10-d13}, [%[r3]]! @ load r3, 8 " "vext.32 q10, q9, q11, #1 @ r1, shift left 1\n"
"float, interleave\n" "vmla.f32 q14, q4, %e[w2][1] @ w21 * r4\n"
"vmla.f32 q15, q6, %e[w2][1] @ w21 * r4\n"
/* r0 * w0, get or0, r3 * w1, get or1*/ "vld1.32 {d22}, [%[r4]] @ load 16th float\n"
"vmla.f32 q12, q8, %e[w0][1] @ w01 * r0, " "vmla.f32 q12, q7, %e[w1][0] @ w10 * r1\n"
"1, 3, 5, 7\n" "vmla.f32 q13, q9, %e[w1][0] @ w10 * r1\n"
"vmla.f32 q13, q10, %e[w0][1] @ w01 * r0, " "vext.32 q4, q3, q5, #1 @ r1, shift left 1\n"
"9, 11, 13, 15\n" "vext.32 q6, q5, q11, #1 @ r1, shift left 1\n"
"vext.32 q8, q7, q9, #1 @ r0, shift " "vmla.f32 q14, q3, %e[w2][0] @ w20 * r4\n"
"left 1, get 2, 4, 6, 8\n" "vmla.f32 q15, q5, %e[w2][0] @ w20 * r4\n"
"vext.32 q10, q9, q11, #1 @ r0, shift " "vmla.f32 q12, q8, %f[w1][0] @ w12 * r1\n"
"left 1, get 10, 12, 14, 16\n" "vmla.f32 q13, q10, %f[w1][0] @ w12 * r1\n"
"vld1.32 {d22}, [%[r3]] @ load 16th " "vmla.f32 q14, q4, %f[w2][0] @ w22 * r4\n"
"float\n" "vmla.f32 q15, q6, %f[w2][0] @ w22 * r4\n"
"vmla.f32 q14, q4, %e[w1][1] @ w11 * r3, " "vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n"
"1, 3, 5, 7\n" "vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n"
"vmla.f32 q15, q6, %e[w1][1] @ w11 * r3, " "subs %[cnt], #1 @ loop count -1\n"
"9, 11, 13, 15\n" "bne 0b @ jump to main loop\n"
: [cnt] "+r"(cnt),
"vmla.f32 q12, q7, %e[w0][0] @ w00 * r0, " [r0] "+r"(r0),[r1] "+r"(r1),[r2] "+r"(r2),
"0, 2, 4, 6\n" [r3] "+r"(r3),[r4] "+r"(r4),
"vmla.f32 q13, q9, %e[w0][0] @ w00 * r0, " [ptr_out0] "+r"(ptr_out0),
"8, 10, 12, 14\n" [ptr_out1] "+r"(ptr_out1)
"vext.32 q4, q3, q5, #1 @ r3, shift " : [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
"left 1, get 2, 4, 6, 8\n" : "cc","memory","q3","q4",
"vext.32 q6, q5, q11, #1 @ r3, shift " "q5","q6","q7","q8","q9","q10",
"left 1, get 10, 12, 14, 16\n" "q11","q12","q13","q14","q15"
"vmla.f32 q14, q3, %e[w1][0] @ w10 * r3, " );
"0, 2, 4, 6\n" // clang-format on
"vmla.f32 q15, q5, %e[w1][0] @ w10 * r3, "
"8, 10, 12, 14\n"
"vmla.f32 q12, q8, %f[w0][0] @ w02 * r0, "
"2, 4, 6, 8\n"
"vld2.32 {d14-d17}, [%[r1]]! @ load r1, 8 "
"float, interleave\n"
"vmla.f32 q13, q10,%f[w0][0] @ w02 * r0, "
"10, 12, 14, 16\n"
"vld2.32 {d18-d21}, [%[r1]]! @ load r1, 8 "
"float, interleave\n"
"vmla.f32 q14, q4, %f[w1][0] @ w12 * r3, "
"2, 4, 6, 8\n"
"vld2.32 {d6-d9}, [%[r4]]! @ load r4, 8 "
"float, interleave\n"
"vmla.f32 q15, q6, %f[w1][0] @ w12 * r3, "
"10, 12, 14, 16\n"
"vld2.32 {d10-d13}, [%[r4]]! @ load r4, 8 "
"float, interleave\n"
"vld1.32 {d22}, [%[r1]] @ load 16th "
"float\n"
/* r1 * w1, get or0, r4 * w2, get or1 */
"vmla.f32 q12, q8, %e[w1][1] @ w11 * r1, "
"1, 3, 5, 7\n"
"vmla.f32 q13, q10, %e[w1][1] @ w11 * r1, "
"9, 11, 13, 15\n"
"vext.32 q8, q7, q9, #1 @ r1, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q10, q9, q11, #1 @ r1, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q4, %e[w2][1] @ w21 * r4, "
"1, 3, 5, 7\n"
"vmla.f32 q15, q6, %e[w2][1] @ w21 * r4, "
"9, 11, 13, 15\n"
"vld1.32 {d22}, [%[r4]] @ load 16th "
"float\n"
"vmla.f32 q12, q7, %e[w1][0] @ w10 * r1, "
"0, 2, 4, 6\n"
"vmla.f32 q13, q9, %e[w1][0] @ w10 * r1, "
"8, 10, 12, 14\n"
"vext.32 q4, q3, q5, #1 @ r1, shift "
"left 1, get 2, 4, 6, 8\n"
"vext.32 q6, q5, q11, #1 @ r1, shift "
"left 1, get 10, 12, 14, 16\n"
"vmla.f32 q14, q3, %e[w2][0] @ w20 * r4, "
"0, 2, 4, 6\n"
"vmla.f32 q15, q5, %e[w2][0] @ w20 * r4, "
"8, 10, 12, 14\n"
"vmla.f32 q12, q8, %f[w1][0] @ w12 * r1, "
"2, 4, 6, 8\n"
"vmla.f32 q13, q10, %f[w1][0] @ w12 * r1, "
"10, 12, 14, 16\n"
"vmla.f32 q14, q4, %f[w2][0] @ w22 * r4, "
"2, 4, 6, 8\n"
"vmla.f32 q15, q6, %f[w2][0] @ w22 * r4, "
"10, 12, 14, 16\n"
"vst1.32 {d24-d27}, [%[ptr_out0]]! @ save or0\n"
"vst1.32 {d28-d31}, [%[ptr_out1]]! @ save or0\n"
"subs %[cnt], #1 @loop count "
"-1\n"
"bne 0b @ jump to "
"main loop\n"
: [cnt] "+r"(cnt),
[r0] "+r"(r0),
[r1] "+r"(r1),
[r2] "+r"(r2),
[r3] "+r"(r3),
[r4] "+r"(r4),
[ptr_out0] "+r"(ptr_out0),
[ptr_out1] "+r"(ptr_out1)
: [w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} }
//! deal with remain ow //! deal with remain ow
if (w_loop & 1) { if (w_loop & 1) {
......
...@@ -23,6 +23,9 @@ namespace lite { ...@@ -23,6 +23,9 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
/// conv 3x3s1
size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_direct_fp32(const float* din, void conv_3x3s1_direct_fp32(const float* din,
float* dout, float* dout,
int num, int num,
...@@ -53,6 +56,9 @@ void conv_3x3s1_direct_int8(const int8_t* din, ...@@ -53,6 +56,9 @@ void conv_3x3s1_direct_int8(const int8_t* din,
ARMContext* ctx, ARMContext* ctx,
const float* scale); const float* scale);
/// conv3x3s2
size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s2_direct_fp32(const float* din, void conv_3x3s2_direct_fp32(const float* din,
float* dout, float* dout,
int num, int num,
......
...@@ -1104,13 +1104,13 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) { ...@@ -1104,13 +1104,13 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) {
SetCacheInfo(0, 1, l1size); SetCacheInfo(0, 1, l1size);
SetCacheInfo(1, 1, l2size); SetCacheInfo(1, 1, l2size);
SetCacheInfo(2, 1, l3size); SetCacheInfo(2, 1, l3size);
workspace_.Resize({2 * (l1size + l2size)}); workspace_.Resize({llc_size()});
workspace_.mutable_data<int8_t>();
} }
bool DeviceInfo::ExtendWorkspace(int size) { bool DeviceInfo::ExtendWorkspace(size_t size) {
workspace_.Resize({size + llc_size()}); workspace_.Resize({size + llc_size()});
workspace_.mutable_data<int8_t>(); return workspace_.mutable_data<int8_t>() != nullptr;
return true;
} }
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
......
...@@ -73,7 +73,7 @@ class DeviceInfo { ...@@ -73,7 +73,7 @@ class DeviceInfo {
T* workspace_data() { T* workspace_data() {
return reinterpret_cast<T*>(workspace_.mutable_data<int8_t>()); return reinterpret_cast<T*>(workspace_.mutable_data<int8_t>());
} }
bool ExtendWorkspace(int size); bool ExtendWorkspace(size_t size);
private: private:
int core_num_; int core_num_;
......
...@@ -19,6 +19,25 @@ namespace lite { ...@@ -19,6 +19,25 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
auto& param = this->template Param<param_t>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
if (last_shape_ == x_dims) {
return;
}
auto& ctx = this->ctx_->template As<ARMContext>();
if (param.strides[0] == 2) {
ctx.ExtendWorkspace(
lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx));
} else {
ctx.ExtendWorkspace(
lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx));
}
}
template <> template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() { void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
...@@ -70,6 +89,9 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -70,6 +89,9 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
} }
} }
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {}
template <> template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() { void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
...@@ -126,6 +148,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -126,6 +148,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
} }
} }
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::ReInitWhenNeeded() {}
template <> template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() { void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
......
...@@ -178,10 +178,12 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> { ...@@ -178,10 +178,12 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
w_scale_); w_scale_);
} }
virtual void ReInitWhenNeeded();
virtual void Run(); virtual void Run();
/// todo, support inplace weights transform /// todo, support inplace weights transform
protected: protected:
DDim last_shape_;
Tensor weights_; Tensor weights_;
Tensor bias_; Tensor bias_;
bool flag_trans_weights_{false}; bool flag_trans_weights_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册