diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc
index c998ddc3a34c2f6194a5156b7d04b7a9db3fbcef..b4539db98c3ffb1a143c38dd3c4dd9e9924bd63e 100644
--- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc
@@ -25,6 +25,73 @@ namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
+void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
+                                    float* o_data,
+                                    int bs,
+                                    int oc,
+                                    int oh,
+                                    int ow,
+                                    int ic,
+                                    int ih,
+                                    int win,
+                                    const float* weights,
+                                    const float* bias,
+                                    float* relu_ptr,
+                                    float* six_ptr,
+                                    float* scale_ptr,
+                                    const operators::ConvParam& param,
+                                    ARMContext* ctx);
+
+void conv_3x3s1_depthwise_fp32_relu(const float* i_data,
+                                    float* o_data,
+                                    int bs,
+                                    int oc,
+                                    int oh,
+                                    int ow,
+                                    int ic,
+                                    int ih,
+                                    int win,
+                                    const float* weights,
+                                    const float* bias,
+                                    float* relu_ptr,
+                                    float* six_ptr,
+                                    float* scale_ptr,
+                                    const operators::ConvParam& param,
+                                    ARMContext* ctx);
+
+void conv_3x3s1_depthwise_fp32_relu6(const float* i_data,
+                                     float* o_data,
+                                     int bs,
+                                     int oc,
+                                     int oh,
+                                     int ow,
+                                     int ic,
+                                     int ih,
+                                     int win,
+                                     const float* weights,
+                                     const float* bias,
+                                     float* relu_ptr,
+                                     float* six_ptr,
+                                     float* scale_ptr,
+                                     const operators::ConvParam& param,
+                                     ARMContext* ctx);
+
+void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data,
+                                         float* o_data,
+                                         int bs,
+                                         int oc,
+                                         int oh,
+                                         int ow,
+                                         int ic,
+                                         int ih,
+                                         int win,
+                                         const float* weights,
+                                         const float* bias,
+                                         float* relu_ptr,
+                                         float* six_ptr,
+                                         float* scale_ptr,
+                                         const operators::ConvParam& param,
+                                         ARMContext* ctx);
 // clang-format off
 #ifdef __aarch64__
 #define COMPUTE \
@@ -335,7 +402,6 @@ namespace math {
   "ldr r0, [%[outl]] @ load outc00 to r0\n"  \
   "vmla.f32 q12, q5, q0 @ w8 * inr32\n"  \
   "vmla.f32 q13, q5, q1 @ w8 * inr33\n"  \
-  "ldr r5, [%[outl], #36] @ load flag_relu to r5\n"  \
   "vmla.f32 q14, q5, q2 @ w8 * inr34\n"  \
   "vmla.f32 q15, q5, q3 @ w8 * inr35\n"  \
   "ldr r1, [%[outl], #4] @ load outc10 to r1\n"  \
@@ -406,7 +472,6 @@ namespace math {
   "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n"  \
   "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n"  \
   "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n"  \
-  "ldr r5, [%[outl], #20] @ load outc11 to r5\n"  \
   "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n"  \
   "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n"  \
   "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n"  \
@@ -417,12 +482,13 @@ namespace math {
   "vst1.32 {d18-d19}, [r1] @ save outc10\n"  \
   "vst1.32 {d20-d21}, [r2] @ save outc20\n"  \
   "vst1.32 {d22-d23}, [r3] @ save outc30\n"  \
+  "ldr r0, [%[outl], #20] @ load outc11 to r0\n"  \
+  "ldr r1, [%[outl], #24] @ load outc21 to r1\n"  \
+  "ldr r2, [%[outl], #28] @ load outc31 to r2\n"  \
   "vst1.32 {d24-d25}, [r4] @ save outc01\n"  \
-  "vst1.32 {d26-d27}, [r5] @ save outc11\n"  \
-  "ldr r0, [%[outl], #24] @ load outc21 to r0\n"  \
-  "ldr r1, [%[outl], #28] @ load outc31 to r1\n"  \
-  "vst1.32 {d28-d29}, [r0] @ save outc21\n"  \
-  "vst1.32 {d30-d31}, [r1] @ save outc31\n"  \
+  "vst1.32 {d26-d27}, [r0] @ save outc11\n"  \
+  "vst1.32 {d28-d29}, [r1] @ save outc21\n"  \
+  "vst1.32 {d30-d31}, [r2] @ save outc31\n"  \
   "b 3f @ branch end\n"  \
   "2: \n"  \
   "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n"  \
@@ -436,291 +502,86 @@ namespace math {
   "3: \n"
 #endif
 // clang-format on
-void act_switch_3x3s1(const float* inr0,
-                      const float* inr1,
-                      const float* inr2,
-                      const float* inr3,
-                      float* out0,
-                      const float* weight_c,
-                      float flag_mask,
-                      void* outl_ptr,
-                      float32x4_t w0,
-                      float32x4_t w1,
-                      float32x4_t w2,
-                      float32x4_t w3,
-                      float32x4_t w4,
-                      float32x4_t w5,
-                      float32x4_t w6,
-                      float32x4_t w7,
-                      float32x4_t w8,
-                      float32x4_t vbias,
-                      const operators::ActivationParam act_param) {
-  bool has_active = act_param.has_active;
-  if (has_active) {
+void conv_3x3s1_depthwise_fp32(const float* i_data,
+                               float* o_data,
+                               int bs,
+                               int oc,
+                               int oh,
+                               int ow,
+                               int ic,
+                               int ih,
+                               int win,
+                               const float* weights,
+                               const float* bias,
+                               const operators::ConvParam& param,
+                               const operators::ActivationParam act_param,
+                               ARMContext* ctx) {
+  float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
+  float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
+  float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
+  if (act_param.has_active) {
     switch (act_param.active_type) {
       case lite_api::ActivationType::kRelu:
-#ifdef __aarch64__
-        asm volatile(COMPUTE RELU STORE
-                     : [inr0] "+r"(inr0),
-                       [inr1] "+r"(inr1),
-                       [inr2] "+r"(inr2),
-                       [inr3] "+r"(inr3),
-                       [out] "+r"(out0)
-                     : [w0] "w"(w0),
-                       [w1] "w"(w1),
-                       [w2] "w"(w2),
-                       [w3] "w"(w3),
-                       [w4] "w"(w4),
-                       [w5] "w"(w5),
-                       [w6] "w"(w6),
-                       [w7] "w"(w7),
-                       [w8] "w"(w8),
-                       [vbias] "w"(vbias),
-                       [outl] "r"(outl_ptr),
-                       [flag_mask] "r"(flag_mask)
-                     : "cc",
-                       "memory",
-                       "v0",
-                       "v1",
-                       "v2",
-                       "v3",
-                       "v4",
-                       "v5",
-                       "v6",
-                       "v7",
-                       "v8",
-                       "v9",
-                       "v10",
-                       "v11",
-                       "v15",
-                       "v16",
-                       "v17",
-                       "v18",
-                       "v19",
-                       "v20",
-                       "v21",
-                       "v22",
-                       "x0",
-                       "x1",
-                       "x2",
-                       "x3",
-                       "x4",
-                       "x5",
-                       "x6",
-                       "x7");
-#else
-#if 1  // def LITE_WITH_ARM_CLANG
-#else
-        asm volatile(COMPUTE RELU STORE
-                     : [r0] "+r"(inr0),
-                       [r1] "+r"(inr1),
-                       [r2] "+r"(inr2),
-                       [r3] "+r"(inr3),
-                       [out0] "+r"(out0),
-                       [wc0] "+r"(weight_c)
-                     : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
-                     : "cc",
-                       "memory",
-                       "q0",
-                       "q1",
-                       "q2",
-                       "q3",
-                       "q4",
-                       "q5",
-                       "q6",
-                       "q7",
-                       "q8",
-                       "q9",
-                       "q10",
-                       "q11",
-                       "q12",
-                       "q13",
-                       "q14",
-                       "q15",
-                       "r0",
-                       "r1",
-                       "r2",
-                       "r3",
-                       "r4",
-                       "r5");
-#endif
-#endif
+        conv_3x3s1_depthwise_fp32_relu(i_data,
+                                       o_data,
+                                       bs,
+                                       oc,
+                                       oh,
+                                       ow,
+                                       ic,
+                                       ih,
+                                       win,
+                                       weights,
+                                       bias,
+                                       relu_ptr,
+                                       six_ptr,
+                                       scale_ptr,
+                                       param,
+                                       ctx);
         break;
       case lite_api::ActivationType::kRelu6:
-#ifdef __aarch64__
-        asm volatile(COMPUTE RELU RELU6 STORE
-                     : [inr0] "+r"(inr0),
-                       [inr1] "+r"(inr1),
-                       [inr2] "+r"(inr2),
-                       [inr3] "+r"(inr3),
-                       [out] "+r"(out0)
-                     : [w0] "w"(w0),
-                       [w1] "w"(w1),
-                       [w2] "w"(w2),
-                       [w3] "w"(w3),
-                       [w4] "w"(w4),
-                       [w5] "w"(w5),
-                       [w6] "w"(w6),
-                       [w7] "w"(w7),
-                       [w8] "w"(w8),
-                       [vbias] "w"(vbias),
-                       [outl] "r"(outl_ptr),
-                       [flag_mask] "r"(flag_mask)
-                     : "cc",
-                       "memory",
-                       "v0",
-                       "v1",
-                       "v2",
-                       "v3",
-                       "v4",
-                       "v5",
-                       "v6",
-                       "v7",
-                       "v8",
-                       "v9",
-                       "v10",
-                       "v11",
-                       "v15",
-                       "v16",
-                       "v17",
-                       "v18",
-                       "v19",
-                       "v20",
-                       "v21",
-                       "v22",
-                       "x0",
-                       "x1",
-                       "x2",
-                       "x3",
-                       "x4",
-                       "x5",
-                       "x6",
-                       "x7");
-#else
-#if 1  // def LITE_WITH_ARM_CLANG
-#else
-        asm volatile(COMPUTE RELU RELU6 STORE
-                     : [r0] "+r"(inr0),
-                       [r1] "+r"(inr1),
-                       [r2] "+r"(inr2),
-                       [r3] "+r"(inr3),
-                       [out0] "+r"(out0),
-                       [wc0] "+r"(weight_c)
-                     : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
-                     : "cc",
-                       "memory",
-                       "q0",
-                       "q1",
-                       "q2",
-                       "q3",
-                       "q4",
-                       "q5",
-                       "q6",
-                       "q7",
-                       "q8",
-                       "q9",
-                       "q10",
-                       "q11",
-                       "q12",
-                       "q13",
-                       "q14",
-                       "q15",
-                       "r0",
-                       "r1",
-                       "r2",
-                       "r3",
-                       "r4",
-                       "r5");
-#endif
-#endif
+        six_ptr[0] = act_param.Relu_clipped_coef;
+        six_ptr[1] = act_param.Relu_clipped_coef;
+        six_ptr[2] = act_param.Relu_clipped_coef;
+        six_ptr[3] = act_param.Relu_clipped_coef;
+        conv_3x3s1_depthwise_fp32_relu6(i_data,
+                                        o_data,
+                                        bs,
+                                        oc,
+                                        oh,
+                                        ow,
+                                        ic,
+                                        ih,
+                                        win,
+                                        weights,
+                                        bias,
+                                        relu_ptr,
+                                        six_ptr,
+                                        scale_ptr,
+                                        param,
+                                        ctx);
         break;
       case lite_api::ActivationType::kLeakyRelu:
-#ifdef __aarch64__
-        asm volatile(COMPUTE LEAKY_RELU STORE
-                     : [inr0] "+r"(inr0),
-                       [inr1] "+r"(inr1),
-                       [inr2] "+r"(inr2),
-                       [inr3] "+r"(inr3),
-                       [out] "+r"(out0)
-                     : [w0] "w"(w0),
-                       [w1] "w"(w1),
-                       [w2] "w"(w2),
-                       [w3] "w"(w3),
-                       [w4] "w"(w4),
-                       [w5] "w"(w5),
-                       [w6] "w"(w6),
-                       [w7] "w"(w7),
-                       [w8] "w"(w8),
-                       [vbias] "w"(vbias),
-                       [outl] "r"(outl_ptr),
-                       [flag_mask] "r"(flag_mask)
-                     : "cc",
-                       "memory",
-                       "v0",
-                       "v1",
-                       "v2",
-                       "v3",
-                       "v4",
-                       "v5",
-                       "v6",
-                       "v7",
-                       "v8",
-                       "v9",
-                       "v10",
-                       "v11",
-                       "v15",
-                       "v16",
-                       "v17",
-                       "v18",
-                       "v19",
-                       "v20",
-                       "v21",
-                       "v22",
-                       "x0",
-                       "x1",
-                       "x2",
-                       "x3",
-                       "x4",
-                       "x5",
-                       "x6",
-                       "x7");
-#else
-#if 1  // def LITE_WITH_ARM_CLANG
-#else
-        asm volatile(COMPUTE LEAKY_RELU STORE
-                     : [r0] "+r"(inr0),
-                       [r1] "+r"(inr1),
-                       [r2] "+r"(inr2),
-                       [r3] "+r"(inr3),
-                       [out0] "+r"(out0),
-                       [wc0] "+r"(weight_c)
-                     : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
-                     : "cc",
-                       "memory",
-                       "q0",
-                       "q1",
-                       "q2",
-                       "q3",
-                       "q4",
-                       "q5",
-                       "q6",
-                       "q7",
-                       "q8",
-                       "q9",
-                       "q10",
-                       "q11",
-                       "q12",
-                       "q13",
-                       "q14",
-                       "q15",
-                       "r0",
-                       "r1",
-                       "r2",
-                       "r3",
-                       "r4",
-                       "r5");
-#endif
-#endif
+        scale_ptr[0] = act_param.Leaky_relu_alpha;
+        scale_ptr[1] = act_param.Leaky_relu_alpha;
+        scale_ptr[2] = act_param.Leaky_relu_alpha;
+        scale_ptr[3] = act_param.Leaky_relu_alpha;
+        conv_3x3s1_depthwise_fp32_leakyRelu(i_data,
+                                            o_data,
+                                            bs,
+                                            oc,
+                                            oh,
+                                            ow,
+                                            ic,
+                                            ih,
+                                            win,
+                                            weights,
+                                            bias,
+                                            relu_ptr,
+                                            six_ptr,
+                                            scale_ptr,
+                                            param,
+                                            ctx);
         break;
       default:
         LOG(FATAL) << "this act_type: "
@@ -728,108 +589,288 @@ void act_switch_3x3s1(const float* inr0,
                    << " fuse not support";
     }
   } else {
-#ifdef __aarch64__
-    asm volatile(COMPUTE STORE
-                 : [inr0] "+r"(inr0),
-                   [inr1] "+r"(inr1),
-                   [inr2] "+r"(inr2),
-                   [inr3] "+r"(inr3),
-                   [out] "+r"(out0)
-                 : [w0] "w"(w0),
-                   [w1] "w"(w1),
-                   [w2] "w"(w2),
-                   [w3] "w"(w3),
-                   [w4] "w"(w4),
-                   [w5] "w"(w5),
-                   [w6] "w"(w6),
-                   [w7] "w"(w7),
-                   [w8] "w"(w8),
-                   [vbias] "w"(vbias),
-                   [outl] "r"(outl_ptr),
-                   [flag_mask] "r"(flag_mask)
-                 : "cc",
-                   "memory",
-                   "v0",
-                   "v1",
-                   "v2",
-                   "v3",
-                   "v4",
-                   "v5",
-                   "v6",
-                   "v7",
-                   "v8",
-                   "v9",
-                   "v10",
-                   "v11",
-                   "v15",
-                   "v16",
-                   "v17",
-                   "v18",
-                   "v19",
-                   "v20",
-                   "v21",
-                   "v22",
-                   "x0",
-                   "x1",
-                   "x2",
-                   "x3",
-                   "x4",
-                   "x5",
-                   "x6",
-                   "x7");
-#else
-#if 1  // def LITE_WITH_ARM_CLANG
+    conv_3x3s1_depthwise_fp32_bias(i_data,
+                                   o_data,
+                                   bs,
+                                   oc,
+                                   oh,
+                                   ow,
+                                   ic,
+                                   ih,
+                                   win,
+                                   weights,
+                                   bias,
+                                   relu_ptr,
+                                   six_ptr,
+                                   scale_ptr,
+                                   param,
+                                   ctx);
+  }
+}
+
+void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
+                                    float* o_data,
+                                    int bs,
+                                    int oc,
+                                    int oh,
+                                    int ow,
+                                    int ic,
+                                    int ih,
+                                    int win,
+                                    const float* weights,
+                                    const float* bias,
+                                    float* relu_ptr,
+                                    float* six_ptr,
+                                    float* scale_ptr,
+                                    const operators::ConvParam& param,
+                                    ARMContext* ctx) {
+  int threads = ctx->threads();
+
+  auto paddings = *param.paddings;
+  const int pad_h = paddings[0];
+  const int pad_w = paddings[2];
+
+  const int out_c_block = 4;
+  const int out_h_kernel = 2;
+  const int out_w_kernel = 4;
+  const int win_ext = ow + 2;
+  const int ow_round = ROUNDUP(ow, 4);
+  const int win_round = ROUNDUP(win_ext, 4);
+  const int hin_round = oh + 2;
+  const int prein_size = win_round * hin_round * out_c_block;
+  auto workspace_size =
+      threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
+  ctx->ExtendWorkspace(sizeof(float) * workspace_size);
+
+  bool flag_bias = param.bias != nullptr;
+
+  /// get workspace
+  float* ptr_zero = ctx->workspace_data<float>();
+  memset(ptr_zero, 0, sizeof(float) * win_round);
+  float* ptr_write = ptr_zero + win_round;
+
+  int size_in_channel = win * ih;
+  int size_out_channel = ow * oh;
+
+  int ws = -pad_w;
+  int we = ws + win_round;
+  int hs = -pad_h;
+  int he = hs + hin_round;
+  int w_loop = ow_round / 4;
+  auto remain = w_loop * 4 - ow;
+  bool flag_remain = remain > 0;
+  remain = 4 - remain;
+  remain = remain > 0 ? remain : 0;
+  int row_len = win_round * out_c_block;
+
+  for (int n = 0; n < bs; ++n) {
+    const float* din_batch = i_data + n * ic * size_in_channel;
+    float* dout_batch = o_data + n * oc * size_out_channel;
+#pragma omp parallel for num_threads(threads)
+    for (int c = 0; c < oc; c += out_c_block) {
+#ifdef ARM_WITH_OMP
+      float* pre_din =
+          ptr_write + ow_round + omp_get_thread_num() * prein_size;
 #else
-    asm volatile(COMPUTE STORE
-                 : [r0] "+r"(inr0),
-                   [r1] "+r"(inr1),
-                   [r2] "+r"(inr2),
-                   [r3] "+r"(inr3),
-                   [out0] "+r"(out0),
-                   [wc0] "+r"(weight_c)
-                 : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
-                 : "cc",
-                   "memory",
-                   "q0",
-                   "q1",
-                   "q2",
-                   "q3",
-                   "q4",
-                   "q5",
-                   "q6",
-                   "q7",
-                   "q8",
-                   "q9",
-                   "q10",
-                   "q11",
-                   "q12",
-                   "q13",
-                   "q14",
-                   "q15",
-                   "r0",
-                   "r1",
-                   "r2",
-                   "r3",
-                   "r4",
-                   "r5");
+      float* pre_din = ptr_write + ow_round;
 #endif
+      /// const array size
+      float pre_out[out_c_block * out_w_kernel * out_h_kernel];  // NOLINT
+      prepack_input_nxwc4_dw(
+          din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
+      const float* weight_c = weights + c * 9;  // kernel_w * kernel_h
+      float* dout_c00 = dout_batch + c * size_out_channel;
+      float bias_local[4] = {0, 0, 0, 0};
+      if (flag_bias) {
+        bias_local[0] = bias[c];
+        bias_local[1] = bias[c + 1];
+        bias_local[2] = bias[c + 2];
+        bias_local[3] = bias[c + 3];
+      }
+      float32x4_t vbias = vld1q_f32(bias_local);
+#ifdef __aarch64__
+      float32x4_t w0 = vld1q_f32(weight_c);       // w0, v23
+      float32x4_t w1 = vld1q_f32(weight_c + 4);   // w1, v24
+      float32x4_t w2 = vld1q_f32(weight_c + 8);   // w2, v25
+      float32x4_t w3 = vld1q_f32(weight_c + 12);  // w3, v26
+      float32x4_t w4 = vld1q_f32(weight_c + 16);  // w4, v27
+      float32x4_t w5 = vld1q_f32(weight_c + 20);  // w5, v28
+      float32x4_t w6 = vld1q_f32(weight_c + 24);  // w6, v29
+      float32x4_t w7 = vld1q_f32(weight_c + 28);  // w7, v30
+      float32x4_t w8 = vld1q_f32(weight_c + 32);  // w8, v31
+#endif
+      for (int h = 0; h < oh; h += out_h_kernel) {
+        float* outc00 = dout_c00 + h * ow;
+        float* outc01 = outc00 + ow;
+        float* outc10 = outc00 + size_out_channel;
+        float* outc11 = outc10 + ow;
+        float* outc20 = outc10 + size_out_channel;
+        float* outc21 = outc20 + ow;
+        float* outc30 = outc20 + size_out_channel;
+        float* outc31 = outc30 + ow;
+        const float* inr0 = pre_din + h * row_len;
+        const float* inr1 = inr0 + row_len;
+        const float* inr2 = inr1 + row_len;
+        const float* inr3 = inr2 + row_len;
+        if (c + out_c_block > oc) {
+          switch (c + out_c_block - oc) {
+            case 3:  // outc10-outc30 are ptr_write and extra
+              outc10 = ptr_write;
+              outc11 = ptr_write;
+            case 2:  // outc20-outc30 are ptr_write and extra
+              outc20 = ptr_write;
+              outc21 = ptr_write;
+            case 1:  // outc30 is ptr_write and extra
+              outc30 = ptr_write;
+              outc31 = ptr_write;
+            default:
+              break;
+          }
+        }
+        if (h + out_h_kernel > oh) {
+          outc01 = ptr_write;
+          outc11 = ptr_write;
+          outc21 = ptr_write;
+          outc31 = ptr_write;
+        }
+
+        float* outl[] = {outc00,
+                         outc10,
+                         outc20,
+                         outc30,
+                         outc01,
+                         outc11,
+                         outc21,
+                         outc31,
+                         reinterpret_cast<float*>(bias_local),
+                         reinterpret_cast<float*>(relu_ptr),
+                         reinterpret_cast<float*>(six_ptr),
+                         reinterpret_cast<float*>(scale_ptr)};
+        void* outl_ptr = reinterpret_cast<void*>(outl);
+        for (int w = 0; w < w_loop; ++w) {
+          bool flag_mask = (w == w_loop - 1) && flag_remain;
+          float* out0 = pre_out;
+#ifdef __aarch64__
+          asm volatile(COMPUTE STORE
+                       : [inr0] "+r"(inr0),
+                         [inr1] "+r"(inr1),
+                         [inr2] "+r"(inr2),
+                         [inr3] "+r"(inr3),
+                         [out] "+r"(out0)
+                       : [w0] "w"(w0),
+                         [w1] "w"(w1),
+                         [w2] "w"(w2),
+                         [w3] "w"(w3),
+                         [w4] "w"(w4),
+                         [w5] "w"(w5),
+                         [w6] "w"(w6),
+                         [w7] "w"(w7),
+                         [w8] "w"(w8),
+                         [vbias] "w"(vbias),
+                         [outl] "r"(outl_ptr),
+                         [flag_mask] "r"(flag_mask)
+                       : "cc",
+                         "memory",
+                         "v0",
+                         "v1",
+                         "v2",
+                         "v3",
+                         "v4",
+                         "v5",
+                         "v6",
+                         "v7",
+                         "v8",
+                         "v9",
+                         "v10",
+                         "v11",
+                         "v15",
+                         "v16",
+                         "v17",
+                         "v18",
+                         "v19",
+                         "v20",
+                         "v21",
+                         "v22",
+                         "x0",
+                         "x1",
+                         "x2",
+                         "x3",
+                         "x4",
+                         "x5",
+                         "x6",
+                         "x7");
+#else
+          asm volatile(COMPUTE STORE
+                       : [r0] "+r"(inr0),
+                         [r1] "+r"(inr1),
+                         [r2] "+r"(inr2),
+                         [r3] "+r"(inr3),
+                         [out0] "+r"(out0),
+                         [wc0] "+r"(weight_c)
+                       : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                       : "cc",
+                         "memory",
+                         "q0",
+                         "q1",
+                         "q2",
+                         "q3",
+                         "q4",
+                         "q5",
+                         "q6",
+                         "q7",
+                         "q8",
+                         "q9",
+                         "q10",
+                         "q11",
+                         "q12",
+                         "q13",
+                         "q14",
+                         "q15",
+                         "r0",
+                         "r1",
+                         "r2",
+                         "r3",
+                         "r4");
+#endif
+          outl[0] += 4;
+          outl[1] += 4;
+          outl[2] += 4;
+          outl[3] += 4;
+          outl[4] += 4;
+          outl[5] += 4;
+          outl[6] += 4;
+          outl[7] += 4;
+          if (flag_mask) {
+            memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
+            memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
+            memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
+            memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
+            memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
+            memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
+            memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
+            memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
+          }
+        }
+      }
+    }
+  }
+}
-#endif
-  }
-}
-
-void conv_3x3s1_depthwise_fp32(const float* i_data,
-                               float* o_data,
-                               int bs,
-                               int oc,
-                               int oh,
-                               int ow,
-                               int ic,
-                               int ih,
-                               int win,
-                               const float* weights,
-                               const float* bias,
-                               const operators::ConvParam& param,
-                               const operators::ActivationParam act_param,
-                               ARMContext* ctx) {
+
+void conv_3x3s1_depthwise_fp32_relu(const float* i_data,
+                                    float* o_data,
+                                    int bs,
+                                    int oc,
+                                    int oh,
+                                    int ow,
+                                    int ic,
+                                    int ih,
+                                    int win,
+                                    const float* weights,
+                                    const float* bias,
+                                    float* relu_ptr,
+                                    float* six_ptr,
+                                    float* scale_ptr,
+                                    const operators::ConvParam& param,
+                                    ARMContext* ctx) {
   int threads = ctx->threads();
 
   auto paddings = *param.paddings;
@@ -869,31 +910,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
   remain = remain > 0 ? remain : 0;
   int row_len = win_round * out_c_block;
 
-  float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
-  float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
-  float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
-  if (act_param.has_active) {
-    switch (act_param.active_type) {
-      case lite_api::ActivationType::kRelu:
-        break;
-      case lite_api::ActivationType::kRelu6:
-        six_ptr[0] = act_param.Relu_clipped_coef;
-        six_ptr[1] = act_param.Relu_clipped_coef;
-        six_ptr[2] = act_param.Relu_clipped_coef;
-        six_ptr[3] = act_param.Relu_clipped_coef;
-        break;
-      case lite_api::ActivationType::kLeakyRelu:
-        scale_ptr[0] = act_param.Leaky_relu_alpha;
-        scale_ptr[1] = act_param.Leaky_relu_alpha;
-        scale_ptr[2] = act_param.Leaky_relu_alpha;
-        scale_ptr[3] = act_param.Leaky_relu_alpha;
-        break;
-      default:
-        LOG(FATAL) << "this act_type: "
-                   << static_cast<int>(act_param.active_type)
-                   << " fuse not support";
-    }
-  }
   for (int n = 0; n < bs; ++n) {
     const float* din_batch = i_data + n * ic * size_in_channel;
     float* dout_batch = o_data + n * oc * size_out_channel;
@@ -944,13 +960,13 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
       const float* inr3 = inr2 + row_len;
       if (c + out_c_block > oc) {
         switch (c + out_c_block - oc) {
-          case 3:
+          case 3:  // outc10-outc30 are ptr_write and extra
            outc10 = ptr_write;
            outc11 = ptr_write;
-          case 2:
+          case 2:  // outc20-outc30 are ptr_write and extra
            outc20 = ptr_write;
            outc21 = ptr_write;
-          case 1:
+          case 1:  // outc30 is ptr_write and extra
            outc30 = ptr_write;
            outc31 = ptr_write;
           default:
@@ -981,48 +997,86 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           bool flag_mask = (w == w_loop - 1) && flag_remain;
           float* out0 = pre_out;
 #ifdef __aarch64__
-          act_switch_3x3s1(inr0,
-                           inr1,
-                           inr2,
-                           inr3,
-                           out0,
-                           weight_c,
-                           flag_mask,
-                           outl_ptr,
-                           w0,
-                           w1,
-                           w2,
-                           w3,
-                           w4,
-                           w5,
-                           w6,
-                           w7,
-                           w8,
-                           vbias,
-                           act_param);
+          asm volatile(COMPUTE RELU STORE
+                       : [inr0] "+r"(inr0),
+                         [inr1] "+r"(inr1),
+                         [inr2] "+r"(inr2),
+                         [inr3] "+r"(inr3),
+                         [out] "+r"(out0)
+                       : [w0] "w"(w0),
+                         [w1] "w"(w1),
+                         [w2] "w"(w2),
+                         [w3] "w"(w3),
+                         [w4] "w"(w4),
+                         [w5] "w"(w5),
+                         [w6] "w"(w6),
+                         [w7] "w"(w7),
+                         [w8] "w"(w8),
+                         [vbias] "w"(vbias),
+                         [outl] "r"(outl_ptr),
+                         [flag_mask] "r"(flag_mask)
+                       : "cc",
+                         "memory",
+                         "v0",
+                         "v1",
+                         "v2",
+                         "v3",
+                         "v4",
+                         "v5",
+                         "v6",
+                         "v7",
+                         "v8",
+                         "v9",
+                         "v10",
+                         "v11",
+                         "v15",
+                         "v16",
+                         "v17",
+                         "v18",
+                         "v19",
+                         "v20",
+                         "v21",
+                         "v22",
+                         "x0",
+                         "x1",
+                         "x2",
+                         "x3",
+                         "x4",
+                         "x5",
+                         "x6",
+                         "x7");
 #else
-#if 1  // def LITE_WITH_ARM_CLANG
-#else
-          act_switch_3x3s1(inr0,
-                           inr1,
-                           inr2,
-                           inr3,
-                           out0,
-                           weight_c,
-                           flag_mask,
-                           outl_ptr,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           vbias,
-                           act_param);
-#endif
+          asm volatile(COMPUTE RELU STORE
+                       : [r0] "+r"(inr0),
+                         [r1] "+r"(inr1),
+                         [r2] "+r"(inr2),
+                         [r3] "+r"(inr3),
+                         [out0] "+r"(out0),
+                         [wc0] "+r"(weight_c)
+                       : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                       : "cc",
+                         "memory",
+                         "q0",
+                         "q1",
+                         "q2",
+                         "q3",
+                         "q4",
+                         "q5",
+                         "q6",
+                         "q7",
+                         "q8",
+                         "q9",
+                         "q10",
+                         "q11",
+                         "q12",
+                         "q13",
+                         "q14",
+                         "q15",
+                         "r0",
+                         "r1",
+                         "r2",
+                         "r3",
+                         "r4");
 #endif
           outl[0] += 4;
           outl[1] += 4;
@@ -1032,10 +1086,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
           outl[5] += 4;
           outl[6] += 4;
           outl[7] += 4;
-          inr0 += 16;
-          inr1 += 16;
-          inr2 += 16;
-          inr3 += 16;
           if (flag_mask) {
             memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
             memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
@@ -1052,6 +1102,499 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
     }
   }
 }
 
+void conv_3x3s1_depthwise_fp32_relu6(const float* i_data,
+                                     float* o_data,
+                                     int bs,
+                                     int oc,
+                                     int oh,
+                                     int ow,
+                                     int ic,
+                                     int ih,
+                                     int win,
+                                     const float* weights,
+                                     const float* bias,
+                                     float* relu_ptr,
+                                     float* six_ptr,
+                                     float* scale_ptr,
+                                     const operators::ConvParam& param,
+                                     ARMContext* ctx) {
+  int threads = ctx->threads();
+
+  auto paddings = *param.paddings;
+  const int pad_h = paddings[0];
+  const int pad_w = paddings[2];
+
+  const int out_c_block = 4;
+  const int out_h_kernel = 2;
+  const int out_w_kernel = 4;
+  const int win_ext = ow + 2;
+  const int ow_round = ROUNDUP(ow, 4);
+  const int win_round = ROUNDUP(win_ext, 4);
+  const int hin_round = oh + 2;
+  const int prein_size = win_round * hin_round * out_c_block;
+  auto workspace_size =
+      threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
+  ctx->ExtendWorkspace(sizeof(float) * workspace_size);
+
+  bool flag_bias = param.bias != nullptr;
+
+  /// get workspace
+  float* ptr_zero = ctx->workspace_data<float>();
+  memset(ptr_zero, 0, sizeof(float) * win_round);
+  float* ptr_write = ptr_zero + win_round;
+
+  int size_in_channel = win * ih;
+  int size_out_channel = ow * oh;
+
+  int ws = -pad_w;
+  int we = ws + win_round;
+  int hs = -pad_h;
+  int he = hs + hin_round;
+  int w_loop = ow_round / 4;
+  auto remain = w_loop * 4 - ow;
+  bool flag_remain = remain > 0;
+  remain = 4 - remain;
+  remain = remain > 0 ? remain : 0;
+  int row_len = win_round * out_c_block;
+
+  for (int n = 0; n < bs; ++n) {
+    const float* din_batch = i_data + n * ic * size_in_channel;
+    float* dout_batch = o_data + n * oc * size_out_channel;
+#pragma omp parallel for num_threads(threads)
+    for (int c = 0; c < oc; c += out_c_block) {
+#ifdef ARM_WITH_OMP
+      float* pre_din =
+          ptr_write + ow_round + omp_get_thread_num() * prein_size;
+#else
+      float* pre_din = ptr_write + ow_round;
+#endif
+      /// const array size
+      float pre_out[out_c_block * out_w_kernel * out_h_kernel];  // NOLINT
+      prepack_input_nxwc4_dw(
+          din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
+      const float* weight_c = weights + c * 9;  // kernel_w * kernel_h
+      float* dout_c00 = dout_batch + c * size_out_channel;
+      float bias_local[4] = {0, 0, 0, 0};
+      if (flag_bias) {
+        bias_local[0] = bias[c];
+        bias_local[1] = bias[c + 1];
+        bias_local[2] = bias[c + 2];
+        bias_local[3] = bias[c + 3];
+      }
+      float32x4_t vbias = vld1q_f32(bias_local);
+#ifdef __aarch64__
+      float32x4_t w0 = vld1q_f32(weight_c);       // w0, v23
+      float32x4_t w1 = vld1q_f32(weight_c + 4);   // w1, v24
+      float32x4_t w2 = vld1q_f32(weight_c + 8);   // w2, v25
+      float32x4_t w3 = vld1q_f32(weight_c + 12);  // w3, v26
+      float32x4_t w4 = vld1q_f32(weight_c + 16);  // w4, v27
+      float32x4_t w5 = vld1q_f32(weight_c + 20);  // w5, v28
+      float32x4_t w6 = vld1q_f32(weight_c + 24);  // w6, v29
+      float32x4_t w7 = vld1q_f32(weight_c + 28);  // w7, v30
+      float32x4_t w8 = vld1q_f32(weight_c + 32);  // w8, v31
+#endif
+      for (int h = 0; h < oh; h += out_h_kernel) {
+        float* outc00 = dout_c00 + h * ow;
+        float* outc01 = outc00 + ow;
+        float* outc10 = outc00 + size_out_channel;
+        float* outc11 = outc10 + ow;
+        float* outc20 = outc10 + size_out_channel;
+        float* outc21 = outc20 + ow;
+        float* outc30 = outc20 + size_out_channel;
+        float* outc31 = outc30 + ow;
+        const float* inr0 = pre_din + h * row_len;
+        const float* inr1 = inr0 + row_len;
+        const float* inr2 = inr1 + row_len;
+        const float* inr3 = inr2 + row_len;
+        if (c + out_c_block > oc) {
+          switch (c + out_c_block - oc) {
+            case 3:  // outc10-outc30 are ptr_write and extra
+              outc10 = ptr_write;
+              outc11 = ptr_write;
+            case 2:  // outc20-outc30 are ptr_write and extra
+              outc20 = ptr_write;
+              outc21 = ptr_write;
+            case 1:  // outc30 is ptr_write and extra
+              outc30 = ptr_write;
+              outc31 = ptr_write;
+            default:
+              break;
+          }
+        }
+        if (h + out_h_kernel > oh) {
+          outc01 = ptr_write;
+          outc11 = ptr_write;
+          outc21 = ptr_write;
+          outc31 = ptr_write;
+        }
+
+        float* outl[] = {outc00,
+                         outc10,
+                         outc20,
+                         outc30,
+                         outc01,
+                         outc11,
+                         outc21,
+                         outc31,
+                         reinterpret_cast<float*>(bias_local),
+                         reinterpret_cast<float*>(relu_ptr),
+                         reinterpret_cast<float*>(six_ptr),
+                         reinterpret_cast<float*>(scale_ptr)};
+        void* outl_ptr = reinterpret_cast<void*>(outl);
+        for (int w = 0; w < w_loop; ++w) {
+          bool flag_mask = (w == w_loop - 1) && flag_remain;
+          float* out0 = pre_out;
+#ifdef __aarch64__
+          asm volatile(COMPUTE RELU RELU6 STORE
+                       : [inr0] "+r"(inr0),
+                         [inr1] "+r"(inr1),
+                         [inr2] "+r"(inr2),
+                         [inr3] "+r"(inr3),
+                         [out] "+r"(out0)
+                       : [w0] "w"(w0),
+                         [w1] "w"(w1),
+                         [w2] "w"(w2),
+                         [w3] "w"(w3),
+                         [w4] "w"(w4),
+                         [w5] "w"(w5),
+                         [w6] "w"(w6),
+                         [w7] "w"(w7),
+                         [w8] "w"(w8),
+                         [vbias] "w"(vbias),
+                         [outl] "r"(outl_ptr),
+                         [flag_mask] "r"(flag_mask)
+                       : "cc",
+                         "memory",
+                         "v0",
+                         "v1",
+                         "v2",
+                         "v3",
+                         "v4",
+                         "v5",
+                         "v6",
+                         "v7",
+                         "v8",
+                         "v9",
+                         "v10",
+                         "v11",
+                         "v15",
+                         "v16",
+                         "v17",
+                         "v18",
+                         "v19",
+                         "v20",
+                         "v21",
+                         "v22",
+                         "x0",
+                         "x1",
+                         "x2",
+                         "x3",
+                         "x4",
+                         "x5",
+                         "x6",
+                         "x7");
+#else
+          asm volatile(COMPUTE RELU RELU6 STORE
+                       : [r0] "+r"(inr0),
+                         [r1] "+r"(inr1),
+                         [r2] "+r"(inr2),
+                         [r3] "+r"(inr3),
+                         [out0] "+r"(out0),
+                         [wc0] "+r"(weight_c)
+                       : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                       : "cc",
+                         "memory",
+                         "q0",
+                         "q1",
+                         "q2",
+                         "q3",
+                         "q4",
+                         "q5",
+                         "q6",
+                         "q7",
+                         "q8",
+                         "q9",
+                         "q10",
+                         "q11",
+                         "q12",
+                         "q13",
+                         "q14",
+                         "q15",
+                         "r0",
+                         "r1",
+                         "r2",
+                         "r3",
+                         "r4");
+#endif
+          outl[0] += 4;
+          outl[1] += 4;
+          outl[2] += 4;
+          outl[3] += 4;
+          outl[4] += 4;
+          outl[5] += 4;
+          outl[6] += 4;
+          outl[7] += 4;
+          if (flag_mask) {
+            memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
+            memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
+            memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
+            memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
+            memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
+            memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
+            memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
+            memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
+          }
+        }
+      }
+    }
+  }
+}
+
+void conv_3x3s1_depthwise_fp32_leakyRelu(const float* i_data,
+                                         float* o_data,
+                                         int bs,
+                                         int oc,
+                                         int oh,
+                                         int ow,
+                                         int ic,
+                                         int ih,
+                                         int win,
+                                         const float* weights,
+                                         const float* bias,
+                                         float* relu_ptr,
+                                         float* six_ptr,
+                                         float* scale_ptr,
+                                         const operators::ConvParam& param,
+                                         ARMContext* ctx) {
+  int threads = ctx->threads();
+
+  auto paddings = *param.paddings;
+  const int pad_h = paddings[0];
+  const int pad_w = paddings[2];
+
+  const int out_c_block = 4;
+  const int out_h_kernel = 2;
+  const int out_w_kernel = 4;
+  const int win_ext = ow + 2;
+  const int ow_round = ROUNDUP(ow, 4);
+  const int win_round = ROUNDUP(win_ext, 4);
+  const int hin_round = oh + 2;
+  const int prein_size = win_round * hin_round * out_c_block;
+  auto workspace_size =
+      threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
+  ctx->ExtendWorkspace(sizeof(float) * workspace_size);
+
+  bool flag_bias = param.bias != nullptr;
+
+  /// get workspace
+  float* ptr_zero = ctx->workspace_data<float>();
+  memset(ptr_zero, 0, sizeof(float) * win_round);
+  float* ptr_write = ptr_zero + win_round;
+
+  int size_in_channel = win * ih;
+  int size_out_channel = ow * oh;
+
+  int ws = -pad_w;
+  int we = ws + win_round;
+  int hs = -pad_h;
+  int he = hs + hin_round;
+  int w_loop = ow_round / 4;
+  auto remain = w_loop * 4 - ow;
+  bool flag_remain = remain > 0;
+  remain = 4 - remain;
+  remain = remain > 0 ? remain : 0;
+  int row_len = win_round * out_c_block;
+
+  for (int n = 0; n < bs; ++n) {
+    const float* din_batch = i_data + n * ic * size_in_channel;
+    float* dout_batch = o_data + n * oc * size_out_channel;
+#pragma omp parallel for num_threads(threads)
+    for (int c = 0; c < oc; c += out_c_block) {
+#ifdef ARM_WITH_OMP
+      float* pre_din =
+          ptr_write + ow_round + omp_get_thread_num() * prein_size;
+#else
+      float* pre_din = ptr_write + ow_round;
+#endif
+      /// const array size
+      float pre_out[out_c_block * out_w_kernel * out_h_kernel];  // NOLINT
+      prepack_input_nxwc4_dw(
+          din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero);
+      const float* weight_c = weights + c * 9;  // kernel_w * kernel_h
+      float* dout_c00 = dout_batch + c * size_out_channel;
+      float bias_local[4] = {0, 0, 0, 0};
+      if (flag_bias) {
+        bias_local[0] = bias[c];
+        bias_local[1] = bias[c + 1];
+        bias_local[2] = bias[c + 2];
+        bias_local[3] = bias[c + 3];
+      }
+      float32x4_t vbias = vld1q_f32(bias_local);
+#ifdef __aarch64__
+      float32x4_t w0 = vld1q_f32(weight_c);       // w0, v23
+      float32x4_t w1 = vld1q_f32(weight_c + 4);   // w1, v24
+      float32x4_t w2 = vld1q_f32(weight_c + 8);   // w2, v25
+      float32x4_t w3 = vld1q_f32(weight_c + 12);  // w3, v26
+      float32x4_t w4 = vld1q_f32(weight_c + 16);  // w4, v27
+      float32x4_t w5 = vld1q_f32(weight_c + 20);  // w5, v28
+      float32x4_t w6 = vld1q_f32(weight_c + 24);  // w6, v29
+      float32x4_t w7 = vld1q_f32(weight_c + 28);  // w7, v30
+      float32x4_t w8 = vld1q_f32(weight_c + 32);  // w8, v31
+#endif
+      for (int h = 0; h < oh; h += out_h_kernel) {
+        float* outc00 = dout_c00 + h * ow;
+        float* outc01 = outc00 + ow;
+        float* outc10 = outc00 + size_out_channel;
+        float* outc11 = outc10 + ow;
+        float* outc20 = outc10 + size_out_channel;
+        float* outc21 = outc20 + ow;
+        float* outc30 = outc20 + size_out_channel;
+        float* outc31 = outc30 + ow;
+        const float* inr0 = pre_din + h * row_len;
+        const float* inr1 = inr0 + row_len;
+        const float* inr2 = inr1 + row_len;
+        const float* inr3 = inr2 + row_len;
+        if (c + out_c_block > oc) {
+          switch (c + out_c_block - oc) {
+            case 3:  // outc10-outc30 are ptr_write and extra
+              outc10 = ptr_write;
+              outc11 = ptr_write;
+            case 2:  // outc20-outc30 are ptr_write and extra
+              outc20 = ptr_write;
+              outc21 = ptr_write;
+            case 1:  // outc30 is ptr_write and extra
+              outc30 = ptr_write;
+              outc31 = ptr_write;
+            default:
+              break;
+          }
+        }
+        if (h + out_h_kernel > oh) {
+          outc01 = ptr_write;
+          outc11 = ptr_write;
+          outc21 = ptr_write;
+          outc31 = ptr_write;
+        }
+
+        float* outl[] = {outc00,
+                         outc10,
+                         outc20,
+                         outc30,
+                         outc01,
+                         outc11,
+                         outc21,
+                         outc31,
+                         reinterpret_cast<float*>(bias_local),
+                         reinterpret_cast<float*>(relu_ptr),
+                         reinterpret_cast<float*>(six_ptr),
+                         reinterpret_cast<float*>(scale_ptr)};
+        void* outl_ptr = reinterpret_cast<void*>(outl);
+        for (int w = 0; w < w_loop; ++w) {
+          bool flag_mask = (w == w_loop - 1) && flag_remain;
+          float* out0 = pre_out;
+#ifdef __aarch64__
+          asm volatile(COMPUTE LEAKY_RELU STORE
+                       : [inr0] "+r"(inr0),
+                         [inr1] "+r"(inr1),
+                         [inr2] "+r"(inr2),
+                         [inr3] "+r"(inr3),
+                         [out] "+r"(out0)
+                       : [w0] "w"(w0),
+                         [w1] "w"(w1),
+                         [w2] "w"(w2),
+                         [w3] "w"(w3),
+                         [w4] "w"(w4),
+                         [w5] "w"(w5),
+                         [w6] "w"(w6),
+                         [w7] "w"(w7),
+                         [w8] "w"(w8),
+                         [vbias] "w"(vbias),
+                         [outl] "r"(outl_ptr),
+                         [flag_mask] "r"(flag_mask)
+                       : "cc",
+                         "memory",
+                         "v0",
+                         "v1",
+                         "v2",
+                         "v3",
+                         "v4",
+                         "v5",
+                         "v6",
+                         "v7",
+                         "v8",
+                         "v9",
+                         "v10",
+                         "v11",
+                         "v15",
+                         "v16",
+                         "v17",
+                         "v18",
+                         "v19",
+                         "v20",
+                         "v21",
+                         "v22",
+                         "x0",
+                         "x1",
+                         "x2",
+                         "x3",
+                         "x4",
+                         "x5",
+                         "x6",
+                         "x7");
+#else
+          asm volatile(COMPUTE LEAKY_RELU STORE
+                       : [r0] "+r"(inr0),
+                         [r1] "+r"(inr1),
+                         [r2] "+r"(inr2),
+                         [r3] "+r"(inr3),
+                         [out0] "+r"(out0),
+                         [wc0] "+r"(weight_c)
+                       : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
+                       : "cc",
+                         "memory",
+                         "q0",
+                         "q1",
+                         "q2",
+                         "q3",
+                         "q4",
+                         "q5",
+                         "q6",
+                         "q7",
+                         "q8",
+                         "q9",
+                         "q10",
+                         "q11",
+                         "q12",
+                         "q13",
+                         "q14",
+                         "q15",
+                         "r0",
+                         "r1",
+                         "r2",
+                         "r3",
+                         "r4");
+#endif
+          outl[0] += 4;
+          outl[1] += 4;
+          outl[2] += 4;
+          outl[3] += 4;
+          outl[4] += 4;
+          outl[5] += 4;
+          outl[6] += 4;
+          outl[7] += 4;
+          if (flag_mask) {
+            memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
+            memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
+            memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float));
+            memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float));
+            memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float));
+            memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float));
+            memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float));
+            memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float));
+          }
+        }
+      }
+    }
+  }
+}
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
"r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4"); +#endif + outl[0] += 4; + outl[1] += 4; + outl[2] += 4; + outl[3] += 4; + outl[4] += 4; + outl[5] += 4; + outl[6] += 4; + outl[7] += 4; + if (flag_mask) { + memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); + memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); + memcpy(outl[2] - 4, pre_out + 8, remain * sizeof(float)); + memcpy(outl[3] - 4, pre_out + 12, remain * sizeof(float)); + memcpy(outl[4] - 4, pre_out + 16, remain * sizeof(float)); + memcpy(outl[5] - 4, pre_out + 20, remain * sizeof(float)); + memcpy(outl[6] - 4, pre_out + 24, remain * sizeof(float)); + memcpy(outl[7] - 4, pre_out + 28, remain * sizeof(float)); + } + } + } + } + } +} } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 2bad1f997f457429c013c11a1dce35eb43dc26da..fa2f85311b3ff4247d52505d750566ec80e47256 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -620,8 +620,10 @@ void conv_depthwise_3x3_fp32(const void* din, int pad = pad_w; bool flag_bias = param.bias != nullptr; bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); + bool ch_four = ch_in <= 4 * w_in; if (stride == 1) { - if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] + if (ch_four && pads_less && (pad_h == pad_w) && + (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s1_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -638,7 +640,6 @@ void conv_depthwise_3x3_fp32(const void* din, act_param, ctx); } else { -#ifdef __aarch64__ conv_3x3s1_depthwise_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -653,30 +654,10 @@ void conv_depthwise_3x3_fp32(const void* din, param, act_param, ctx); -#else -#ifdef LITE_WITH_ARM_CLANG - LOG(FATAL) << "fp32 depthwise conv3x3s1px doesnot support in v7-clang, " - "this can run in basic"; -#else - conv_3x3s1_depthwise_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, - reinterpret_cast(weights), - bias, - param, - act_param, - ctx); -#endif -#endif } } else if (stride == 2) { - if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] + if (ch_four && pads_less && pad_h == pad_w && + (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 54e67de5abbfc88f64a50b07335d2527d9738206..ba7837cfff312a15c9ec769ab4e8ac16d0945f4d 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -59,12 +59,6 @@ void ConvCompute::PrepareForRun() { bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2); bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 
diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc
index 3558eb22fbd4863771bf2b6b2e62e51b75a1227e..e34da16acdd71b490b0a233513f525668618a288 100644
--- a/lite/kernels/arm/conv_depthwise.cc
+++ b/lite/kernels/arm/conv_depthwise.cc
@@ -28,11 +28,15 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
   auto& ctx = this->ctx_->template As<ARMContext>();
   auto w_dims = param.filter->dims();
   auto kw = w_dims[3];
+  auto channel = w_dims[0];
+  auto hin = param.x->dims()[2];
+  auto win = param.x->dims()[3];
   auto paddings = *param.paddings;
+  bool ch_four = channel <= 4 * win;
   // select dw conv kernel
   if (kw == 3) {
     bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
-    if (pads_less && paddings[0] == paddings[2] &&
+    if (ch_four && pads_less && paddings[0] == paddings[2] &&
         (paddings[0] == 0 || paddings[0] == 1)) {
       flag_trans_weights_ = false;
     } else {
@@ -398,6 +402,14 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
       w_scale_.data());
 }
 
+#ifdef LITE_WITH_PROFILE
+template <>
+void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::
+    SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
+  ch->kernel_func_name = kernel_func_name_;
+}
+#endif
+
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
diff --git a/lite/tests/math/sgemm_compute_test.cc b/lite/tests/math/sgemm_compute_test.cc
index 9255e5cdced7698c80bc86e2747393149ab13236..11f39ccf57a8e51d456bc1f8f81ac47308dd6c20 100644
--- a/lite/tests/math/sgemm_compute_test.cc
+++ b/lite/tests/math/sgemm_compute_test.cc
@@ -39,6 +39,7 @@ DEFINE_int32(power_mode,
 DEFINE_int32(threads, 1, "threads num");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
+
 #ifdef LITE_WITH_ARM
 // sgemm_test wiil not be operated except that it's
 // on arm backend.
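Reviewer's note (appended after the patch, not part of it): all four specialized kernels duplicate the same right-edge handling. Each inner iteration computes a 4-pixel tile per output row; when `ow` is not a multiple of 4 the last tile overhangs, so the asm stores it into the scratch tile `pre_out` and only the valid lanes are copied back. The self-contained C++ sketch below mirrors that `w_loop`/`remain`/`flag_mask` arithmetic with hypothetical sizes and a plain `memcpy` standing in for the NEON stores; it is an illustration of the indexing scheme, not code from the patch.

```cpp
#include <cstdio>
#include <cstring>

int main() {
  const int ow = 10;                      // hypothetical output width
  const int ow_round = (ow + 3) / 4 * 4;  // ROUNDUP(ow, 4) == 12
  const int w_loop = ow_round / 4;        // 3 tiles of 4 pixels each
  int remain = w_loop * 4 - ow;           // overhang past ow: 2
  const bool flag_remain = remain > 0;
  remain = 4 - remain;                    // valid pixels in the last tile: 2
  remain = remain > 0 ? remain : 0;

  float row[12] = {0.f};                        // one padded output row
  const float tile[4] = {1.f, 2.f, 3.f, 4.f};   // stands in for pre_out
  float* out = row;
  for (int w = 0; w < w_loop; ++w) {
    const bool flag_mask = (w == w_loop - 1) && flag_remain;
    if (!flag_mask) {
      // Full tile: the kernels store directly through the outc pointers.
      memcpy(out, tile, 4 * sizeof(float));
    }
    out += 4;  // advance as if a full tile had been written
    if (flag_mask) {
      // Partial tile: results were staged in scratch; copy valid lanes only,
      // exactly like the `memcpy(outl[i] - 4, pre_out + ..., remain * ...)`
      // calls in the kernels above.
      memcpy(out - 4, tile, remain * sizeof(float));
    }
  }
  printf("w_loop=%d valid_in_last_tile=%d\n", w_loop, remain);
  return 0;
}
```

Staging the edge tile in scratch keeps the hot asm path branch-free for full tiles and confines the bounds logic to one `memcpy` per row pointer, at the cost of repeating the scheme in all four activation-specialized copies.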