diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc index 6f056677378ad0499e0f2ce8b0dd56cee5d6a6ae..ddb3675faf7c67e2e4cc38e054c2e22f397d9de7 100644 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @@ -25,7 +25,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -40,7 +39,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -55,7 +53,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -70,7 +67,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -93,7 +89,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, const float *bias, int pad, bool flag_bias, - bool flag_relu, const operators::ActivationParam act_param, ARMContext *ctx) { if (pad == 0) { @@ -103,7 +98,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, @@ -118,7 +112,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, @@ -136,7 +129,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, @@ -151,7 +143,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, @@ -163,7 +154,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, } } } -// clang-format on + #ifdef __aarch64__ #define INIT_S1 \ "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ @@ -2318,7 +2309,6 @@ void act_switch_3x3s1p1(const float *din_ptr0, } } #endif -// clang-format on /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width > 4 @@ -2328,7 +2318,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -2857,7 +2846,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -3443,7 +3431,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -3579,129 +3566,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout, } int cnt = tile_w; - /* - if (flag_relu) { - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0) - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0) - "ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234 - "ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345 - "ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0) - "ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0) - MID_COMPUTE_S1 MID_RESULT_S1_RELU - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } else { - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0) - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0) - "ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234 - "ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345 - "ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0) - "ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0) - MID_COMPUTE_S1 MID_RESULT_S1 - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1 "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } - */ act_switch_3x3s1p0(din_ptr0, din_ptr1, din_ptr2, @@ -3760,90 +3624,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout, int cnt = tile_w; unsigned int *rmask_ptr = rmask; unsigned int *vmask_ptr = vmask; - /* - if (flag_relu) { - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1_RELU - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1 - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1 "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - }*/ act_switch_3x3s1p0(din_ptr0, din_ptr1, din_ptr2, @@ -4174,7 +3954,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, const float *weights, const float *bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, @@ -4213,14 +3992,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, float32x4_t wr1 = vld1q_f32(weight_ptr + 3); float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - // #ifdef __aarch64__ - // float32x4_t wbias; - // if (flag_bias) { - // wbias = vdupq_n_f32(bias[i]); - // } else { - // wbias = vdupq_n_f32(0.f); - // } - // #endif // __aarch64__ float32x4_t wbias; float bias_val = 0.f; if (flag_bias) { @@ -4261,137 +4032,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, break; } } - /* - #ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [vzero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [vzero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } - #else - unsigned int *vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - #endif - */ unsigned int *vmask_ptr = vmask; act_switch_3x3s1p0_s(dr0, dr1, diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index fd54e214cf27e001e21efcf255b09113bbe12d19..7ca7536beb890ec419341776b9098340883753a5 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -836,7 +836,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; ctx->ExtendWorkspace(sizeof(float) * workspace_size); - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; /// get workspace diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index 602239a1fe1675c6eecb5b45a8e526ada98a56bb..a17d87b47dbe6cae95b7f6bc67f68a3e795d3231 100644 --- a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "lite/backends/arm/math/conv_block_utils.h" #include "lite/backends/arm/math/conv_depthwise.h" namespace paddle { @@ -24,13 +25,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2p0_bias_s(float* dout, @@ -38,13 +39,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2p1_bias(float* dout, @@ -52,13 +53,13 @@ void conv_depthwise_3x3s2p1_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2p1_bias_s(float* dout, @@ -66,13 +67,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s2_fp32(const float* din, @@ -88,7 +89,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, const float* bias, int pad, bool flag_bias, - bool flag_relu, + const operators::ActivationParam act_param, ARMContext* ctx) { if (pad == 0) { if (w_in > 7) { @@ -97,13 +98,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s2p0_bias_s(dout, @@ -111,13 +112,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } } @@ -128,13 +129,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s2p1_bias_s(dout, @@ -142,13 +143,13 @@ void conv_depthwise_3x3s2_fp32(const float* din, weights, bias, flag_bias, - flag_relu, num, ch_in, h_in, w_in, h_out, w_out, + act_param, ctx); } } @@ -412,6 +413,83 @@ void conv_depthwise_3x3s2_fp32(const float* din, "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ \ "blt 1f \n" +#define LEFT_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" \ + "ld1 {v22.4s}, [%[six_ptr]] \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define LEFT_RESULT_S2_LEAKY_RELU \ + "ld1 {v22.4s}, [%[scale_ptr]] \n" \ + "cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" #define MID_RESULT_S2_RELU \ "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ @@ -438,6 +516,58 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "bne 2b \n" +#define MID_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define MID_RESULT_S2_LEAKY_RELU \ + "cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + #define RIGHT_RESULT_S2_RELU \ "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ @@ -456,6 +586,47 @@ void conv_depthwise_3x3s2_fp32(const float* din, "st1 {v17.4s}, [%[outptr1]], #16 \n" \ "4: \n" +#define RIGHT_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define RIGHT_RESULT_S2_LEAKY_RELU \ + "cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + #define COMPUTE_S_S2 \ "movi v9.4s, #0 \n" \ "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ @@ -500,7 +671,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "fmax v4.4s, v4.4s, v9.4s \n" \ \ "st1 {v4.4s}, [%[out]] \n" - #define COMPUTE_S_S2_P0 \ "movi v9.4s, #0 \n" \ "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ @@ -537,7 +707,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "fadd v4.4s, v4.4s, v16.4s \n" #define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n" - #define RESULT_S_S2_P0_RELU \ "fmax v4.4s, v4.4s, v9.4s \n" \ "st1 {v4.4s}, [%[out]] \n" @@ -682,7 +851,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "vst1.32 {d6-d7}, [%[outptr]]! \n" \ "cmp %[cnt], #1 \n" \ "blt 1f \n" - #define MID_RESULT_S2_RELU \ "vmax.f32 q3, q3, q9 @ relu \n" \ "subs %[cnt], #1 \n" \ @@ -739,7 +907,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "vadd.f32 q3, q3, q5 @ add \n" #define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n" - #define RESULT_S_S2_RELU \ "vmax.f32 q3, q3, q9 @ relu\n" \ \ @@ -787,13 +954,233 @@ void conv_depthwise_3x3s2_fp32(const float* din, "vadd.f32 q3, q3, q5 @ add \n" #define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n" - #define RESULT_S_S2_P0_RELU \ "vmax.f32 q3, q3, q9 @ relu \n" \ "vst1.32 {d6-d7}, [%[out]] \n" - #endif - +#ifdef __aarch64__ +void act_switch_3x3s2p1(const float* din0_ptr, + const float* din1_ptr, + const float* din2_ptr, + const float* din3_ptr, + const float* din4_ptr, + float* doutr0_ptr, + float* doutr1_ptr, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + uint32x4_t wmask, + float32x4_t wbias, + float32x4_t vzero, + int cnt, + int cnt_remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 + MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(vsix), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU + MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU + RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [scale_ptr] "r"(vscale), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } +} +#endif /** * \brief depthwise convolution kernel 3x3, stride 2 * w_in > 7 @@ -803,13 +1190,13 @@ void conv_depthwise_3x3s2p1_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -821,7 +1208,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, cnt_col++; size_right_remain -= 8; } - int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // + int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); int size_right_pad = w_out * 2 - w_in; @@ -935,96 +1322,24 @@ void conv_depthwise_3x3s2p1_bias(float* dout, doutr1_ptr = write_ptr; } int cnt = cnt_col; - if (flag_relu) { - asm volatile( - INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 - MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } else { - asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 - MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } + act_switch_3x3s2p1(din0_ptr, + din1_ptr, + din2_ptr, + din3_ptr, + din4_ptr, + doutr0_ptr, + doutr1_ptr, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + wmask, + wbias, + vzero, + cnt, + cnt_remain, + act_param); doutr0 = doutr0 + 2 * w_out; } #else @@ -1061,65 +1376,37 @@ void conv_depthwise_3x3s2p1_bias(float* dout, } int cnt = cnt_col; unsigned int* mask_ptr = dmask; - if (flag_relu) { - asm volatile( - INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 - MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 - MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + // do act + if (act_param.has_active) { + act_switch_process(doutr0, doutr0, w_out, &act_param); } doutr0 = doutr0 + w_out; } @@ -1136,13 +1423,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -1198,108 +1485,59 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, unsigned int* mask_ptr = dmask; #ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } else { - asm volatile(COMPUTE_S_S2 RESULT_S_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); #else - if (flag_relu) { - asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S2 RESULT_S_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif + // do act + if (act_param.has_active) { + act_switch_process(out_buf, out_buf, w_out, &act_param); + } for (int w = 0; w < w_out; ++w) { *dout_channel++ = out_buf[w]; } @@ -1310,6 +1548,269 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout, } } +#ifdef __aarch64__ +void act_switch_3x3s2p0(const float* din0_ptr, + const float* din1_ptr, + const float* din2_ptr, + const float* din3_ptr, + const float* din4_ptr, + float* doutr0_ptr, + float* doutr1_ptr, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + uint32x4_t wmask, + float32x4_t wbias, + float32x4_t vzero, + int cnt, + int cnt_remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2_RELU6 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU6 + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(vsix), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_LEAKY_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(vscale), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + } +} +#endif /** * \brief depthwise convolution kernel 3x3, stride 2 */ @@ -1319,13 +1820,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -1438,117 +1939,24 @@ void conv_depthwise_3x3s2p0_bias(float* dout, doutr1_ptr = write_ptr; } int cnt = tile_w; - if (flag_relu) { - asm volatile( - INIT_S2 - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - MID_COMPUTE_S2 MID_RESULT_S2_RELU - "cmp %w[remain], #1 \n" - "blt 4f \n" RIGHT_COMPUTE_S2 - RIGHT_RESULT_S2_RELU - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } else { - asm volatile( - INIT_S2 - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - MID_COMPUTE_S2 MID_RESULT_S2 - "cmp %w[remain], #1 \n" - "blt 4f \n" RIGHT_COMPUTE_S2 - RIGHT_RESULT_S2 - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - } + act_switch_3x3s2p0(din0_ptr, + din1_ptr, + din2_ptr, + din3_ptr, + din4_ptr, + doutr0_ptr, + doutr1_ptr, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + wmask, + wbias, + vzero, + cnt, + cnt_remain, + act_param); doutr0 = doutr0 + 2 * w_out; } #else @@ -1576,64 +1984,36 @@ void conv_depthwise_3x3s2p0_bias(float* dout, } int cnt = tile_w; unsigned int* mask_ptr = dmask; - if (flag_relu) { - asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU - RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 - RIGHT_RESULT_S2 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + if (act_param.has_active) { + act_switch_process(doutr0, doutr0, w_out, &act_param); } doutr0 = doutr0 + w_out; } @@ -1650,13 +2030,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, const float* weights, const float* bias, bool flag_bias, - bool flag_relu, const int num, const int ch_in, const int h_in, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext* ctx) { int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; int out_pad_idx[4] = {0, 1, 2, 3}; @@ -1718,114 +2098,62 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, unsigned int* mask_ptr = dmask; #ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - } else { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - } + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + #else - if (flag_relu) { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(dmask) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(dmask) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif + if (act_param.has_active) { + act_switch_process(out_buf, out_buf, w_out, &act_param); + } for (int w = 0; w < w_out; ++w) { *dout_channel++ = out_buf[w]; } diff --git a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc index 9852c0f84eae8451ef795c95faddfc88e833bea8..8ecc21134017d6469071eb2adc4b2215877c8437 100644 --- a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc @@ -25,6 +25,511 @@ namespace paddle { namespace lite { namespace arm { namespace math { +#ifdef __aarch64__ +#define COMPUTE \ + "ldr q8, [%[bias]]\n" /* load bias */ \ + "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \ + "and v19.16b, v8.16b, v8.16b\n" \ + "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \ + "and v20.16b, v8.16b, v8.16b\n" \ + "ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/ \ + "and v21.16b, v8.16b, v8.16b\n" \ + "ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/ \ + "and v22.16b, v8.16b, v8.16b\n" \ + "ldr q8, [%[inr0]]\n" /* load input r0*/ \ + "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/ \ + "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/ \ + "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/ \ + "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \ + "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "ldr q8, [%[inr1]]\n" /* load input r1*/ \ + "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/ \ + "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/ \ + "fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/ \ + "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/ \ + "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/ \ + "ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/ \ + "ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/ \ + "ldr q8, [%[inr2]]\n" /* load input r2*/ \ + "fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/ \ + "fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/ \ + "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/ \ + "fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/ \ + "fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/ \ + "fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/ \ + "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/ \ + "fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/ \ + "fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/ \ + "fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/ \ + "fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/ \ + "fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/ \ + "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ +#define RELU /* relu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" +#define RELU6 /* relu6 */ \ + "fmin v19.4s, v19.4s, %[vsix].4s\n" \ + "fmin v20.4s, v20.4s, %[vsix].4s\n" \ + "fmin v21.4s, v21.4s, %[vsix].4s\n" \ + "fmin v22.4s, v22.4s, %[vsix].4s\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ + "cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ + "cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ + "cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ + "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ + "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ + "bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \ + "bif v19.16b, v8.16b, v7.16b \n" /* choose*/ +#define STORE /* save result */ \ + "str q19, [%[outc0]], #16\n" \ + "str q20, [%[outc1]], #16\n" \ + "str q21, [%[outc2]], #16\n" \ + "str q22, [%[outc3]], #16\n" + +#else +#define COMPUTE \ + /* fill with bias */ \ + "vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ /* load weights */ \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */ \ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \ + "vand.i32 q12, q8, q8\n" \ + "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \ + "vand.i32 q13, q8, q8\n" \ + "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \ + "vand.i32 q14, q8, q8\n" \ + "vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/ \ + "vand.i32 q15, q8, q8\n" \ + "vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/ \ + "vmla.f32 q12, q9, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q9, q2 @ w0 * inr2\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */ \ + "vmla.f32 q14, q9, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q9, q6 @ w0 * inr6\n" \ + "vmla.f32 q12, q10, q1 @ w1 * inr1\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \ + "vmla.f32 q13, q10, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q10, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q10, q7 @ w1 * inr7\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */ \ + "vmla.f32 q12, q11, q2 @ w2 * inr2\n" \ + "vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n" \ + "vmla.f32 q13, q11, q4 @ w2 * inr4\n" \ + "vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n" \ + "vmla.f32 q14, q11, q6 @ w2 * inr6\n" \ + "vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n" \ + "vmla.f32 q15, q11, q8 @ w2 * inr8\n" /* mul r1 with w3, w4*/ \ + "vmla.f32 q12, q9, q0 @ w3 * inr0\n" \ + "vmla.f32 q13, q9, q2 @ w3 * inr2\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */ \ + "vmla.f32 q14, q9, q4 @ w3 * inr4\n" \ + "vmla.f32 q15, q9, q6 @ w3 * inr6\n" \ + "vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/ \ + "vmla.f32 q12, q10, q1 @ w4 * inr1\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \ + "vmla.f32 q13, q10, q3 @ w4 * inr3\n" \ + "vmla.f32 q14, q10, q5 @ w4 * inr5\n" \ + "vmla.f32 q15, q10, q7 @ w4 * inr7\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ \ + "vmla.f32 q12, q11, q2 @ w5 * inr2\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \ + "vmla.f32 q13, q11, q4 @ w5 * inr4\n" \ + "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \ + "vmla.f32 q14, q11, q6 @ w5 * inr6\n" \ + "vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n" \ + "vmla.f32 q15, q11, q8 @ w5 * inr8\n" /* mul r2 with w6, w7*/ \ + "vmla.f32 q12, q9, q0 @ w6 * inr0\n" \ + "vmla.f32 q13, q9, q2 @ w6 * inr2\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */ \ + "vmla.f32 q14, q9, q4 @ w6 * inr4\n" \ + "vmla.f32 q15, q9, q6 @ w6 * inr6\n" \ + "vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/ \ + "vmla.f32 q12, q10, q1 @ w7 * inr1\n" \ + "vmla.f32 q13, q10, q3 @ w7 * inr3\n" \ + "vmla.f32 q14, q10, q5 @ w7 * inr5\n" \ + "vmla.f32 q15, q10, q7 @ w7 * inr7\n" \ + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \ + "vmla.f32 q12, q11, q2 @ w8 * inr2\n" \ + "vmla.f32 q13, q11, q4 @ w8 * inr4\n" \ + "vmla.f32 q14, q11, q6 @ w8 * inr6\n" \ + "vmla.f32 q15, q11, q8 @ w8 * inr8\n" /* transpose */ \ + "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \ + "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \ + "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \ + "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ +#define RELU /* relu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[six_ptr]]\n" \ + "vmax.f32 q12, q12, q0\n" \ + "vmax.f32 q13, q13, q0\n" \ + "vmax.f32 q14, q14, q0\n" \ + "vmax.f32 q15, q15, q0\n" +#define RELU6 /* relu6 */ \ + "vmin.f32 q12, q12, q1\n" \ + "vmin.f32 q13, q13, q1\n" \ + "vmin.f32 q14, q14, q1\n" \ + "vmin.f32 q15, q15, q1\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[scale_ptr]]\n" \ + "vcge.f32 q2, q12, q0 @ q0 > 0 \n" \ + "vcge.f32 q4, q13, q0 @ q0 > 0 \n" \ + "vcge.f32 q6, q14, q0 @ q0 > 0 \n" \ + "vcge.f32 q8, q15, q0 @ q0 > 0 \n" \ + "vmul.f32 q3, q12, q1 @ mul \n" \ + "vmul.f32 q5, q13, q1 @ mul \n" \ + "vmul.f32 q7, q14, q1 @ mul \n" \ + "vmul.f32 q9, q15, q1 @ mul \n" \ + "vbif q12, q3, q2 @ choose \n" \ + "vbif q13, q5, q4 @ choose \n" \ + "vbif q14, q7, q6 @ choose \n" \ + "vbif q15, q9, q8 @ choose \n" +#define STORE /* save result */ \ + "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \ + "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \ + "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \ + "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ + +#endif + +void act_switch_3x3s2(const float* inr0, + const float* inr1, + const float* inr2, + float* outc0, + float* outc1, + float* outc2, + float* outc3, + const float* weight_c, + float* bias_local, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + float32x4_t w7, + float32x4_t w8, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(tmp); + float32x4_t vscale = vdupq_n_f32(ss); +#else + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local), + [vsix] "w"(vsix) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local), + [vscale] "w"(vscale) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [bias] "r"(bias_local) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} void conv_3x3s2_depthwise_fp32(const float* i_data, float* o_data, @@ -38,6 +543,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx) { auto paddings = *param.paddings; int threads = ctx->threads(); @@ -51,11 +557,9 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, const int win_round = ROUNDUP(win_ext, 4); const int hin_round = oh * 2 + 1; const int prein_size = win_round * hin_round * out_c_block; - auto workspace_size = - threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/; + auto workspace_size = threads * prein_size + win_round + ow_round; ctx->ExtendWorkspace(sizeof(float) * workspace_size); - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; /// get workspace @@ -77,6 +581,8 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, remain = remain > 0 ? remain : 0; int row_len = win_round * out_c_block; + float32x4_t vzero = vdupq_n_f32(0.f); + for (int n = 0; n < bs; ++n) { const float* din_batch = i_data + n * ic * size_in_channel; float* dout_batch = o_data + n * oc * size_out_channel; @@ -147,201 +653,47 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, outc2 = pre_out + 8; outc3 = pre_out + 12; } -// clang-format off #ifdef __aarch64__ - asm volatile( - "ldr q8, [%[bias]]\n" /* load bias */ - "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ - "and v19.16b, v8.16b, v8.16b\n" - "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ - "and v20.16b, v8.16b, v8.16b\n" - "ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/ - "and v21.16b, v8.16b, v8.16b\n" - "ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/ - "and v22.16b, v8.16b, v8.16b\n" - "ldr q8, [%[inr0]]\n" /* load input r0*/ - /* r0 mul w0-w2, get out */ - "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ - "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ - "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ - "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ - "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ - "ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/ - "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ - "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ - "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ - "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ - "ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/ - "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ - "ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/ - "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ - "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ - "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ - "ldr q8, [%[inr1]]\n" /* load input r1*/ - /* r1, mul w3-w5, get out */ - "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/ - "fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/ - "fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/ - "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/ - "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/ - "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ - "fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/ - "fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/ - "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/ - "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/ - "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ - "fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/ - "ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/ - "fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/ - "ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/ - "fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/ - "ldr q8, [%[inr2]]\n" /* load input r2*/ - /* r2, mul w6-w8, get out r0, r1 */ - "fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/ - "fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/ - "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/ - "fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/ - "fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/ - "fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/ - "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/ - "fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/ - "fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/ - "fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/ - "fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/ - "fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/ - /* transpose */ - "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ - "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ - "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ - "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ - "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ - "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ - "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ - "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ - /* relu */ - "cbz %w[flag_relu], 0f\n" /* skip relu*/ - "movi v0.4s, #0\n" /* for relu */ - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - /* save result */ - "0:\n" - "str q19, [%[outc0]], #16\n" - "str q20, [%[outc1]], #16\n" - "str q21, [%[outc2]], #16\n" - "str q22, [%[outc3]], #16\n" - :[inr0] "+r"(inr0), [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [outc0]"+r"(outc0), [outc1]"+r"(outc1), - [outc2]"+r"(outc2), [outc3]"+r"(outc3) - :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), - [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), - [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), - [bias] "r" (bias_local), [flag_relu]"r"(flag_relu) - : "cc", "memory", - "v0","v1","v2","v3","v4","v5","v6","v7", - "v8", "v19","v20","v21","v22" - ); + act_switch_3x3s2(inr0, + inr1, + inr2, + outc0, + outc1, + outc2, + outc3, + weight_c, + bias_local, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + w7, + w8, + act_param); #else - asm volatile( - /* fill with bias */ - "vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ - /* load weights */ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */ - "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ - "vand.i32 q12, q8, q8\n" - "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ - "vand.i32 q13, q8, q8\n" - "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ - "vand.i32 q14, q8, q8\n" - "vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/ - "vand.i32 q15, q8, q8\n" - "vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/ - /* mul r0 with w0, w1, w2 */ - "vmla.f32 q12, q9, q0 @ w0 * inr0\n" - "vmla.f32 q13, q9, q2 @ w0 * inr2\n" - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */ - "vmla.f32 q14, q9, q4 @ w0 * inr4\n" - "vmla.f32 q15, q9, q6 @ w0 * inr6\n" - "vmla.f32 q12, q10, q1 @ w1 * inr1\n" - "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" - "vmla.f32 q13, q10, q3 @ w1 * inr3\n" - "vmla.f32 q14, q10, q5 @ w1 * inr5\n" - "vmla.f32 q15, q10, q7 @ w1 * inr7\n" - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */ - "vmla.f32 q12, q11, q2 @ w2 * inr2\n" - "vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n" - "vmla.f32 q13, q11, q4 @ w2 * inr4\n" - "vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n" - "vmla.f32 q14, q11, q6 @ w2 * inr6\n" - "vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n" - "vmla.f32 q15, q11, q8 @ w2 * inr8\n" - /* mul r1 with w3, w4, w5 */ - "vmla.f32 q12, q9, q0 @ w3 * inr0\n" - "vmla.f32 q13, q9, q2 @ w3 * inr2\n" - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */ - "vmla.f32 q14, q9, q4 @ w3 * inr4\n" - "vmla.f32 q15, q9, q6 @ w3 * inr6\n" - "vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/ - "vmla.f32 q12, q10, q1 @ w4 * inr1\n" - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" - "vmla.f32 q13, q10, q3 @ w4 * inr3\n" - "vmla.f32 q14, q10, q5 @ w4 * inr5\n" - "vmla.f32 q15, q10, q7 @ w4 * inr7\n" - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ - "vmla.f32 q12, q11, q2 @ w5 * inr2\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" - "vmla.f32 q13, q11, q4 @ w5 * inr4\n" - "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" - "vmla.f32 q14, q11, q6 @ w5 * inr6\n" - "vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n" - "vmla.f32 q15, q11, q8 @ w5 * inr8\n" - /* mul r2 with w6, w7, w8 */ - "vmla.f32 q12, q9, q0 @ w6 * inr0\n" - "vmla.f32 q13, q9, q2 @ w6 * inr2\n" - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */ - "vmla.f32 q14, q9, q4 @ w6 * inr4\n" - "vmla.f32 q15, q9, q6 @ w6 * inr6\n" - "vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/ - "vmla.f32 q12, q10, q1 @ w7 * inr1\n" - "vmla.f32 q13, q10, q3 @ w7 * inr3\n" - "vmla.f32 q14, q10, q5 @ w7 * inr5\n" - "vmla.f32 q15, q10, q7 @ w7 * inr7\n" - "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" - "vmla.f32 q12, q11, q2 @ w8 * inr2\n" - "vmla.f32 q13, q11, q4 @ w8 * inr4\n" - "vmla.f32 q14, q11, q6 @ w8 * inr6\n" - "vmla.f32 q15, q11, q8 @ w8 * inr8\n" - /* transpose */ - "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ - "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ - "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ - "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ - "cmp %[flag_relu], #0\n" - "beq 0f\n" /* skip relu*/ - "vmov.u32 q0, #0\n" - "vmax.f32 q12, q12, q0\n" - "vmax.f32 q13, q13, q0\n" - "vmax.f32 q14, q14, q0\n" - "vmax.f32 q15, q15, q0\n" - "0:\n" - "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ - "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ - "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ - "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ - :[r0] "+r"(inr0), [r1] "+r"(inr1), - [r2] "+r"(inr2), [wc0] "+r" (weight_c), - [outc0]"+r"(outc0), [outc1]"+r"(outc1), - [outc2]"+r"(outc2), [outc3]"+r"(outc3) - :[bias] "r" (bias_local), - [flag_relu]"r"(flag_relu) - :"cc", "memory", - "q0","q1","q2","q3","q4","q5","q6","q7", - "q8", "q9","q10","q11","q12","q13","q14","q15" - ); -#endif // __arch64__ - // clang-format off + act_switch_3x3s2(inr0, + inr1, + inr2, + outc0, + outc1, + outc2, + outc3, + weight_c, + bias_local, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + act_param); +#endif if (flag_mask) { for (int i = 0; i < remain; ++i) { c0[i] = pre_out[i]; @@ -350,6 +702,13 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, c3[i] = pre_out[i + 12]; } } + inr0 += 32; + inr1 += 32; + inr2 += 32; + outc0 += 4; + outc1 += 4; + outc2 += 4; + outc3 += 4; } } } diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index c4fe965d0b17fa56d76812af14b40bddbc5b313a..e148bb6928c0ddea7486cf8b23cd75e8201318ce 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -2151,6 +2151,210 @@ inline void act_switch_c8_fp32(const float* din_ptr, } } +#ifdef __aarch64__ +#define LOAD_DATA \ + "1: \n" \ + "ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ +#define DO_RELU \ + "fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \ + "fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \ + "fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \ + "fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */ +#define DO_RELU6 \ + "fmin v0.4s, v0.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ + "fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ + "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ + "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */ +#define DO_LEAKY_RELU \ + "cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ + "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v10.16b \n" /* choose*/ +#define DO_STORE \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ + "bne 1b \n" +#else +#define LOAD_DATA \ + "1: \n" \ + "vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \ + "vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \ + "vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \ + "vld1.32 {d12-d13}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" +#define DO_RELU \ + "vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \ + "vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \ + "vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \ + "vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n" +#define DO_RELU6 \ + "vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \ + "vmin.f32 q4, q4, %q[vsix] @ vmaxq_f32() \n" \ + "vmin.f32 q5, q5, %q[vsix] @ vmaxq_f32() \n" \ + "vmin.f32 q6, q6, %q[vsix] @ vmaxq_f32() \n" +#define DO_LEAKY_RELU \ + "vcge.f32 q7, q3, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \ + "vcge.f32 q9, q4, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \ + "vcge.f32 q11, q5, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \ + "vcge.f32 q13, q6, %q[vzero] @ vcgeq_u32 \n" \ + "vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \ + "vbif q3, q8, q7 @ choose \n" \ + "vbif q4, q10, q9 @ choose \n" \ + "vbif q5, q12, q11 @ choose \n" \ + "vbif q6, q13, q13 @ choose \n" +#define DO_STORE \ + "subs %[cnt], #1 \n" \ + "vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \ + "bne 1b \n" +#endif +/* +* Data do activation process +* Now support relu relu6 leakyrelu act +*/ +inline void act_switch_process(float* src, + float* dst, + int size, + const operators::ActivationParam* act_param) { + int cnt = size >> 4; + int remain = size % 16; + float32x4_t vzero = vdupq_n_f32(0.f); + if (act_param != nullptr && act_param->has_active) { + float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha); + if (cnt > 0) { + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile( + LOAD_DATA DO_RELU DO_STORE + : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) + : [vzero] "w"(vzero) + : "memory", "cc", "v0", "v1", "v2", "v3"); +#else + asm volatile( + LOAD_DATA DO_RELU DO_STORE + : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) + : [vzero] "w"(vzero) + : "memory", "cc", "q3", "q4", "q5", "q6"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile( + LOAD_DATA DO_RELU DO_RELU6 DO_STORE + : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) + : [vzero] "w"(vzero), [vsix] "w"(vsix) + : "memory", "cc", "v0", "v1", "v2", "v3"); +#else + asm volatile( + LOAD_DATA DO_RELU DO_RELU6 DO_STORE + : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) + : [vzero] "w"(vzero), [vsix] "w"(vsix) + : "memory", "cc", "q3", "q4", "q5", "q6"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile( + LOAD_DATA DO_LEAKY_RELU DO_STORE + : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) + : [vzero] "w"(vzero), [vscale] "w"(vscale) + : "memory", + "cc", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm volatile( + LOAD_DATA DO_LEAKY_RELU DO_STORE + : [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt) + : [vzero] "w"(vzero), [vscale] "w"(vscale) + : "memory", + "cc", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } + // remain + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (int i = 0; i < remain; i++) { + *dst = *src >= 0.f ? *src : 0.f; + src++; + dst++; + } + case lite_api::ActivationType::kRelu6: + for (int i = 0; i < remain; i++) { + float tmp = *src >= 0.f ? *src : 0.f; + *dst = tmp <= act_param->Relu_clipped_coef + ? tmp + : act_param->Relu_clipped_coef; + src++; + dst++; + } + case lite_api::ActivationType::kLeakyRelu: + for (int i = 0; i < remain; i++) { + if (*src >= 0.f) { + *dst = *src; + } else { + *dst = *src * act_param->Leaky_relu_alpha; + } + src++; + dst++; + } + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } +} + /*wirte result in outputs * input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] */ diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index 503dab29b6c4f0b9d3ff30a89060e473194216a9..bb85e747747c880dd11bd16b57b4db0779ac3683 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -52,6 +52,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx); void conv_depthwise_3x3s1_fp32(const float* din, @@ -67,7 +68,6 @@ void conv_depthwise_3x3s1_fp32(const float* din, const float* bias, int pad, bool flag_bias, - bool flag_relu, const operators::ActivationParam act_param, ARMContext* ctx); @@ -84,7 +84,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, const float* bias, int pad, bool flag_bias, - bool flag_relu, + const operators::ActivationParam act_param, ARMContext* ctx); template diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 642d1c2c1b964b9553e522d70a086531f1706420..9156cd162f55c47b5b8d6e005814b5b3e5009004 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -584,7 +584,6 @@ void conv_depthwise_3x3_fp32(const void* din, const int pad_w = paddings[2]; int stride = param.strides[1]; int pad = pad_w; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; bool pads_equal = ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); @@ -603,7 +602,6 @@ void conv_depthwise_3x3_fp32(const void* din, bias, pad, flag_bias, - flag_relu, act_param, ctx); } else { @@ -638,7 +636,7 @@ void conv_depthwise_3x3_fp32(const void* din, bias, pad, flag_bias, - flag_relu, + act_param, ctx); } else { conv_3x3s2_depthwise_fp32(reinterpret_cast(din), @@ -653,6 +651,7 @@ void conv_depthwise_3x3_fp32(const void* din, reinterpret_cast(weights), bias, param, + act_param, ctx); } } else { diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc index d9c0ecb4fd8457782ac90850b8b6a002c7dfcffe..9ae52d1cb6a406dc8d1059ad97f3757dbc0a31fa 100644 --- a/lite/operators/conv_op.cc +++ b/lite/operators/conv_op.cc @@ -52,12 +52,12 @@ inline int ConvOutputSize(int input_size, return output_size; } -inline void UpdatePaddingAndDilation(std::vector* paddings, - std::vector* dilations, - const std::vector& strides, - const std::string padding_algorithm, - const lite::DDim data_dims, - const lite::DDim& ksize) { +void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilations, + const std::vector& strides, + const std::string padding_algorithm, + const lite::DDim data_dims, + const lite::DDim& ksize) { // when padding_desc is "VALID" or "SAME" if (padding_algorithm == "SAME") { for (size_t i = 0; i < strides.size(); ++i) { diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 24848803fb7ea2139f87aa5b5f2119592dc00084..6e1c0bb3d45448cfe985e3cef9ba8a038c94fb7d 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -136,7 +136,13 @@ class ConvOpLite : public OpLite { mutable ConvParam param_; std::string padding_algorithm_{""}; }; - +// update padding dilation +void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilations, + const std::vector& strides, + const std::string padding_algorithm, + const lite::DDim data_dims, + const lite::DDim& ksize); } // namespace operators } // namespace lite } // namespace paddle