Commit 3455ab0a authored by HappyAngel, committed by yiicy

[lite][arm] add conv+relu6/leakyRelu fusion (#2599)

Parent ec8353e8
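
This change threads an ActivationParam through the arm fp32 direct and
depthwise conv kernels so that relu6 and leaky-relu can be fused into the
store stage, alongside the existing relu fusion. For reference, the fused
epilogues compute the following element-wise transforms. This is a minimal
scalar sketch, not the library API: the enum and the helper name apply_act
are illustrative, while the fields mirrored here (Relu_clipped_coef,
Leaky_relu_alpha) are the act_param fields used by the kernels below.

    #include <algorithm>

    // Illustrative scalar semantics of the fused activation epilogues.
    enum class ActType { kRelu, kRelu6, kLeakyRelu };

    inline float apply_act(float x, ActType type, float six, float alpha) {
      switch (type) {
        case ActType::kRelu:       // y = max(x, 0)
          return std::max(x, 0.f);
        case ActType::kRelu6:      // y = min(max(x, 0), Relu_clipped_coef)
          return std::min(std::max(x, 0.f), six);
        case ActType::kLeakyRelu:  // y = x >= 0 ? x : x * Leaky_relu_alpha
          return x >= 0.f ? x : x * alpha;
      }
      return x;
    }

In the NEON kernels below, relu6 is realized as an fmax/fmin pair against
vzero/vsix, and leaky relu as cmhs (x >= 0 mask), fmul (x * vscale), and
bif (bitwise select between the two results).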
@@ -295,7 +295,8 @@ void conv_compute_6x6_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
@@ -341,7 +342,8 @@ void conv_compute_6x6_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
}
}
@@ -562,7 +564,8 @@ void conv_compute_2x2_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
@@ -602,7 +605,8 @@ void conv_compute_2x2_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
}
}
@@ -814,7 +818,8 @@ void conv_compute_2x2_3x3_small(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
@@ -854,7 +859,8 @@ void conv_compute_2x2_3x3_small(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
}
}
@@ -76,6 +76,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const int threads = ctx->threads();
int l2_size = ctx->llc_size() / sizeof(float);
auto paddings = *param.paddings;
auto act_param = param.activation_param;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
@@ -469,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
const float* weight_remain_ptr = weights + c_round_down * w_stride;
#pragma omp parallel for num_threads(threads)
@@ -780,7 +782,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
}
}
@@ -32,6 +32,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_s(float *dout,
@@ -46,6 +47,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias(float *dout,
@@ -60,6 +62,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_s(float *dout,
@@ -74,6 +77,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx);
void conv_depthwise_3x3s1_fp32(const float *din,
@@ -90,6 +94,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext *ctx) {
if (pad == 0) {
if (w_in > 5) {
@@ -105,6 +110,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s(dout,
@@ -119,6 +125,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
w_in,
h_out,
w_out,
act_param,
ctx);
}
}
@@ -136,6 +143,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s(dout,
@@ -150,11 +158,12 @@ void conv_depthwise_3x3s1_fp32(const float *din,
w_in,
h_out,
w_out,
act_param,
ctx);
}
}
}
// clang-format on
#ifdef __aarch64__
#define INIT_S1 \
"PRFM PLDL1KEEP, [%[din_ptr0]] \n" \
@@ -255,14 +264,12 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */
"ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ /* r4 */ \
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/
#define LEFT_RESULT_S1 \
/* r4 */ \
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
@@ -345,16 +352,15 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
#define MID_RESULT_S1 \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -411,30 +417,31 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#define RIGHT_COMPUTE_S1 \
"3: \n" \
"movi v20.4s, #0 \n" \
"ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \
"ld1 {v22.4s}, [%[doutr0]] \n" \
"ld1 {v23.4s}, [%[doutr1]] \n" \
"ld1 {v24.4s}, [%[doutr2]] \n" \
"ld1 {v25.4s}, [%[doutr3]] \n" \
\
"bif v0.16b, %[vzero].16b, v18.16b \n" \
"bif v1.16b, %[vzero].16b, v19.16b \n" \
"bif v2.16b, %[vzero].16b, v18.16b \n" \
"bif v3.16b, %[vzero].16b, v19.16b \n" \
"bif v0.16b, v20.16b, v18.16b \n" \
"bif v1.16b, v20.16b, v19.16b \n" \
"bif v2.16b, v20.16b, v18.16b \n" \
"bif v3.16b, v20.16b, v19.16b \n" \
\
"bif v4.16b, %[vzero].16b, v18.16b \n" \
"bif v5.16b, %[vzero].16b, v19.16b \n" \
"bif v6.16b, %[vzero].16b, v18.16b \n" \
"bif v7.16b, %[vzero].16b, v19.16b \n" \
"bif v4.16b, v20.16b, v18.16b \n" \
"bif v5.16b, v20.16b, v19.16b \n" \
"bif v6.16b, v20.16b, v18.16b \n" \
"bif v7.16b, v20.16b, v19.16b \n" \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \
"fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"bif v8.16b, %[vzero].16b, v18.16b \n" \
"bif v9.16b, %[vzero].16b, v19.16b \n" \
"bif v10.16b, %[vzero].16b, v18.16b \n" \
"bif v11.16b, %[vzero].16b, v19.16b \n" \
"bif v8.16b, v20.16b, v18.16b \n" \
"bif v9.16b, v20.16b, v19.16b \n" \
"bif v10.16b, v20.16b, v18.16b \n" \
"bif v11.16b, v20.16b, v19.16b \n" \
\
"fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
@@ -467,15 +474,13 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/
#define RIGHT_RESULT_S1 \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"bif v12.16b, v22.16b, v18.16b \n" \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -520,10 +525,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"st1 {v15.4s}, [%[doutr3]], #16 \n"
#define LEFT_RESULT_S1_RELU \
/* r4 */ \
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \
\
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
\
@@ -570,14 +571,113 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"blt 3f \n"
#define LEFT_RESULT_S1_RELU6 \
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \
"fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
"ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \
\
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \
\
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
\
"fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
"cmp %w[cnt], #1 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"blt 3f \n"
#define LEFT_RESULT_S1_LEAKY_RELU \
"cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \
"bif v13.16b, v21.16b, v19.16b \n" /* choose*/ \
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
\
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \
"cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
\
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
\
"bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \
"cmp %w[cnt], #1 \n" \
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"blt 3f \n"
#define MID_RESULT_S1_RELU \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
"movi v20.4s, #0 \n" \
"fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -598,7 +698,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -617,7 +717,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
/* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \
"fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
@@ -633,20 +733,157 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"subs %w[cnt], %w[cnt], #1 \n" \
\
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \
"fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"bne 1b \n"
#define RIGHT_RESULT_S1_RELU \
#define MID_RESULT_S1_RELU6 \
"movi v20.4s, #0 \n" \
"fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \
\
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" \
\
/* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"bne 1b \n"
#define MID_RESULT_S1_LEAKY_RELU \
"movi v21.4s, #0 \n" \
"cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \
\
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"bif v13.16b, v20.16b, v18.16b \n" /* choose*/ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" \
\
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
\
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
"bne 1b \n"
#define RIGHT_RESULT_S1_RELU \
"fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -664,7 +901,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -680,7 +917,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \
"fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
@@ -690,72 +927,184 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \
"fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \
\
"bif v15.16b, v25.16b, v18.16b \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n"
#define COMPUTE_S_S1 \
"prfm pldl1keep, [%[din0]]\n" \
"prfm pldl1keep, [%[din1]]\n" \
"prfm pldl1keep, [%[din2]]\n" \
"prfm pldl1keep, [%[din3]]\n" \
\
"ld1 {v0.4s}, [%[din0]], #16\n" \
"ld1 {v1.4s}, [%[din1]], #16\n" \
"ld1 {v2.4s}, [%[din2]], #16\n" \
"ld1 {v3.4s}, [%[din3]], #16\n" \
\
"bif v0.16b, %[zero].16b, %[mask].16b\n" \
"bif v1.16b, %[zero].16b, %[mask].16b\n" \
"bif v2.16b, %[zero].16b, %[mask].16b\n" \
"bif v3.16b, %[zero].16b, %[mask].16b\n" \
\
"ext v4.16b, %[zero].16b, v0.16b, #12\n" \
"ext v5.16b, %[zero].16b, v1.16b, #12\n" \
"ext v6.16b, %[zero].16b, v2.16b, #12\n" \
"ext v7.16b, %[zero].16b, v3.16b, #12\n" \
\
"ext v8.16b, v0.16b, %[zero].16b, #4\n" \
"ext v9.16b, v1.16b, %[zero].16b, #4\n" \
"ext v10.16b, v2.16b, %[zero].16b, #4\n" \
"ext v11.16b, v3.16b, %[zero].16b, #4\n" \
\
"fmul v12.4s, v0.4s, %[wr0].s[1]\n" \
"fmul v13.4s, v1.4s, %[wr0].s[1]\n" \
\
"fmul v14.4s, v1.4s, %[wr1].s[1]\n" \
"fmul v15.4s, v2.4s, %[wr1].s[1]\n" \
\
"fmul v16.4s, v2.4s, %[wr2].s[1]\n" \
"fmul v17.4s, v3.4s, %[wr2].s[1]\n" \
\
"fmla v12.4s, v4.4s, %[wr0].s[0]\n" \
"fmla v13.4s, v5.4s, %[wr0].s[0]\n" \
\
"fmla v14.4s, v5.4s, %[wr1].s[0]\n" \
"fmla v15.4s, v6.4s, %[wr1].s[0]\n" \
\
"fmla v16.4s, v6.4s, %[wr2].s[0]\n" \
"fmla v17.4s, v7.4s, %[wr2].s[0]\n" \
\
"fmla v12.4s, v8.4s, %[wr0].s[2]\n" \
"fmla v13.4s, v9.4s, %[wr0].s[2]\n" \
\
"fmla v14.4s, v9.4s, %[wr1].s[2]\n" \
"fmla v15.4s, v10.4s, %[wr1].s[2]\n" \
\
"fmla v16.4s, v10.4s, %[wr2].s[2]\n" \
"fmla v17.4s, v11.4s, %[wr2].s[2]\n" \
\
"fadd v12.4s, v12.4s, v14.4s\n" \
"fadd v12.4s, v12.4s, v16.4s\n" \
\
"fadd v13.4s, v13.4s, v15.4s\n" \
"fadd v13.4s, v13.4s, v17.4s\n" \
\
"fadd v12.4s, v12.4s, %[bias].4s\n" \
#define RIGHT_RESULT_S1_RELU6 \
"fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"bif v12.16b, v22.16b, v18.16b \n" \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
"fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
"bif v13.16b, v23.16b, v18.16b \n" \
\
"fmla v15.4s , v10.4s, v20.s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"bif v14.16b, v24.16b, v18.16b \n" \
"fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \
"bif v15.16b, v25.16b, v18.16b \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n"
#define RIGHT_RESULT_S1_LEAKY_RELU \
"movi v1.4s, #0 \n" \
"cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v12.16b, v21.16b, v20.16b \n" /* choose*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"bif v12.16b, v22.16b, v18.16b \n" \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v13.16b, v21.16b, v20.16b \n" \
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
\
"bif v13.16b, v23.16b, v18.16b \n" \
\
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \
"fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v14.16b, v21.16b, v20.16b \n" \
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"bif v14.16b, v24.16b, v18.16b \n" \
\
"cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \
"fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
"bif v15.16b, v21.16b, v20.16b \n" \
"bif v15.16b, v25.16b, v18.16b \n" \
"st1 {v15.4s}, [%[doutr3]], #16 \n"
#define COMPUTE_S_S1 \
"prfm pldl1keep, [%[din0]]\n" \
"prfm pldl1keep, [%[din1]]\n" \
"prfm pldl1keep, [%[din2]]\n" \
"prfm pldl1keep, [%[din3]]\n" \
\
"ld1 {v0.4s}, [%[din0]], #16\n" \
"ld1 {v1.4s}, [%[din1]], #16\n" \
"ld1 {v2.4s}, [%[din2]], #16\n" \
"ld1 {v3.4s}, [%[din3]], #16\n" \
\
"bif v0.16b, %[vzero].16b, %[mask].16b\n" \
"bif v1.16b, %[vzero].16b, %[mask].16b\n" \
"bif v2.16b, %[vzero].16b, %[mask].16b\n" \
"bif v3.16b, %[vzero].16b, %[mask].16b\n" \
\
"ext v4.16b, %[vzero].16b, v0.16b, #12\n" \
"ext v5.16b, %[vzero].16b, v1.16b, #12\n" \
"ext v6.16b, %[vzero].16b, v2.16b, #12\n" \
"ext v7.16b, %[vzero].16b, v3.16b, #12\n" \
\
"ext v8.16b, v0.16b, %[vzero].16b, #4\n" \
"ext v9.16b, v1.16b, %[vzero].16b, #4\n" \
"ext v10.16b, v2.16b, %[vzero].16b, #4\n" \
"ext v11.16b, v3.16b, %[vzero].16b, #4\n" \
\
"fmul v12.4s, v0.4s, %[wr0].s[1]\n" \
"fmul v13.4s, v1.4s, %[wr0].s[1]\n" \
\
"fmul v14.4s, v1.4s, %[wr1].s[1]\n" \
"fmul v15.4s, v2.4s, %[wr1].s[1]\n" \
\
"fmul v16.4s, v2.4s, %[wr2].s[1]\n" \
"fmul v17.4s, v3.4s, %[wr2].s[1]\n" \
\
"fmla v12.4s, v4.4s, %[wr0].s[0]\n" \
"fmla v13.4s, v5.4s, %[wr0].s[0]\n" \
\
"fmla v14.4s, v5.4s, %[wr1].s[0]\n" \
"fmla v15.4s, v6.4s, %[wr1].s[0]\n" \
\
"fmla v16.4s, v6.4s, %[wr2].s[0]\n" \
"fmla v17.4s, v7.4s, %[wr2].s[0]\n" \
\
"fmla v12.4s, v8.4s, %[wr0].s[2]\n" \
"fmla v13.4s, v9.4s, %[wr0].s[2]\n" \
\
"fmla v14.4s, v9.4s, %[wr1].s[2]\n" \
"fmla v15.4s, v10.4s, %[wr1].s[2]\n" \
\
"fmla v16.4s, v10.4s, %[wr2].s[2]\n" \
"fmla v17.4s, v11.4s, %[wr2].s[2]\n" \
\
"fadd v12.4s, v12.4s, v14.4s\n" \
"fadd v12.4s, v12.4s, v16.4s\n" \
\
"fadd v13.4s, v13.4s, v15.4s\n" \
"fadd v13.4s, v13.4s, v17.4s\n" \
\
"fadd v12.4s, v12.4s, %[bias].4s\n" \
"fadd v13.4s, v13.4s, %[bias].4s\n"
#define RESULT_S_S1 \
@@ -765,16 +1114,42 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define RESULT_S_S1_RELU \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"fmax v12.4s, v12.4s, %[zero].4s\n" \
"fmax v13.4s, v13.4s, %[zero].4s\n" \
\
"st1 {v12.4s}, [%[out1]]\n" \
#define RESULT_S_S1_RELU \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"fmax v12.4s, v12.4s, %[vzero].4s\n" \
"fmax v13.4s, v13.4s, %[vzero].4s\n" \
\
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define RESULT_S_S1_RELU6 \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"fmax v12.4s, v12.4s, %[vzero].4s\n" \
"fmax v13.4s, v13.4s, %[vzero].4s\n" \
\
"fmin v12.4s, v12.4s, %[vsix].4s\n" \
"fmin v13.4s, v13.4s, %[vsix].4s\n" \
\
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define RESULT_S_S1_LEAKY_RELU \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\
"bif v12.16b, v20.16b, v18.16b \n" \
"bif v13.16b, v21.16b, v19.16b \n" \
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"prfm pldl1keep, [%[din0]]\n" \
"prfm pldl1keep, [%[din1]]\n" \
@@ -786,17 +1161,17 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"ld1 {v4.4s, v5.4s}, [%[din2]]\n" \
"ld1 {v6.4s, v7.4s}, [%[din3]]\n" \
\
"bif v0.16b, %[zero].16b, %[mask1].16b\n" \
"bif v1.16b, %[zero].16b, %[mask2].16b\n" \
"bif v0.16b, %[vzero].16b, %[mask1].16b\n" \
"bif v1.16b, %[vzero].16b, %[mask2].16b\n" \
\
"bif v2.16b, %[zero].16b, %[mask1].16b\n" \
"bif v3.16b, %[zero].16b, %[mask2].16b\n" \
"bif v2.16b, %[vzero].16b, %[mask1].16b\n" \
"bif v3.16b, %[vzero].16b, %[mask2].16b\n" \
\
"bif v4.16b, %[zero].16b, %[mask1].16b\n" \
"bif v5.16b, %[zero].16b, %[mask2].16b\n" \
"bif v4.16b, %[vzero].16b, %[mask1].16b\n" \
"bif v5.16b, %[vzero].16b, %[mask2].16b\n" \
\
"bif v6.16b, %[zero].16b, %[mask1].16b\n" \
"bif v7.16b, %[zero].16b, %[mask2].16b\n" \
"bif v6.16b, %[vzero].16b, %[mask1].16b\n" \
"bif v7.16b, %[vzero].16b, %[mask2].16b\n" \
\
"ext v8.16b, v0.16b, v1.16b, #4\n" \
"ext v9.16b, v0.16b, v1.16b, #8\n" \
@@ -849,7 +1224,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
// "st1 {v12.4s}, [%[out1]]\n" \
// "st1 {v13.4s}, [%[out2]]\n" \
#else
#define INIT_S1 \
"pld [%[din0_ptr]] @ preload data\n" \
@@ -1129,6 +1503,66 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define LEFT_RESULT_S1_RELU6 \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.f32 {d28-d29}, [%[six_ptr]] @ load six \n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
\
"vmin.f32 q4, q4, q14 @ relu6 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
"vdup.32 q4, %[bias_val] @ and \n" \
"vmin.f32 q5, q5, q14 @ relu6 \n" \
"cmp %[cnt], #1 @ check whether has mid cols\n" \
\
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define LEFT_RESULT_S1_LEAKY_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
"vld1.f32 {d28-d29}, [%[scale_ptr]] @ load scale \n" \
\
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
"vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vbif q4, q6, q15 @ choose \n" \
"vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q5, q14 \n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vbif q5, q6, q7 @ choose \n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
"cmp %[cnt], #1 @ check whether has mid cols\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define MID_RESULT_S1_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
@@ -1157,6 +1591,69 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"bne 1b @ jump to main loop start point\n"
#define MID_RESULT_S1_RELU6 \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[six_ptr]]! @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmin.f32 q4, q4, q14 @ relu6 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vmin.f32 q5, q5, q14 @ relu6 \n" \
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"subs %[cnt], #1 @ loop count minus 1\n" \
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
\
"bne 1b @ jump to main loop start point\n"
#define MID_RESULT_S1_LEAKY_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
\
"vbif q4, q6, q15 @ choose \n" \
"vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
\
"vbif q5, q6, q7 @ choose \n" \
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"subs %[cnt], #1 @ loop count minus 1\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vdup.32 q5, %[bias_val] @ and \n" \
\
"bne 1b @ jump to main loop start point\n"
#define RIGHT_RESULT_S1_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
@@ -1178,6 +1675,58 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define RIGHT_RESULT_S1_RELU6 \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmin.f32 q4, q4, q14 @ relu6 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vbif d8, d16, d19 @ bit select, deal with right pad\n" \
"vbif d9, d17, d23 @ bit select, deal with right pad\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmin.f32 q5, q5, q14 @ relu6 \n" \
"vbif d10, d20, d19 @ bit select, deal with right pad\n" \
"vbif d11, d21, d23 @ bit select, deal with right pad\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define RIGHT_RESULT_S1_LEAKY_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q4, q14 \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vbif q4, q6, q15 @ choose \n" \
\
"vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q6, q5, q14 \n" \
\
"vbif d8, d16, d19 @ bit select, deal with right pad\n" \
"vbif d9, d17, d23 @ bit select, deal with right pad\n" \
"vbif q5, q6, q7 @ choose \n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vbif d10, d20, d19 @ bit select, deal with right pad\n" \
"vbif d11, d21, d23 @ bit select, deal with right pad\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define COMPUTE_S_S1 \
"pld [%[din0]]\n" \
"pld [%[din1]]\n" \
@@ -1251,6 +1800,36 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define RESULT_S_S1_RELU6 \
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vld1.32 {d20-d21}, [%[six_ptr]] \n" \
"vmax.f32 q14, q14, %q[vzero]\n" \
"vmax.f32 q15, q15, %q[vzero]\n" \
\
"vmin.f32 q14, q14, q10 \n" \
"vmin.f32 q15, q15, q10 \n" \
\
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define RESULT_S_S1_LEAKY_RELU \
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vld1.32 {d18-d19}, [%[scale_ptr]] \n" \
"vcge.f32 q10, q14, %q[vzero] @ q0 > 0 \n" \
"vcge.f32 q11, q15, %q[vzero] @ q0 > 0 \n" \
"vmul.f32 q12, q14, q9 \n" \
"vmul.f32 q13, q15, q9 \n" \
\
"vbif q14, q10, q12 \n" \
"vbif q15, q11, q13 \n" \
\
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"pld [%[din0]]\n" \
"pld [%[din1]]\n" \
@@ -1333,6 +1912,413 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"vadd.f32 q15, q5, q9 @ q4 += q10 \n"
#endif
#ifdef __aarch64__
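// Dispatches the 4-output-row 3x3s1p1 depthwise kernel on the fused
// activation type: vsix broadcasts act_param.Relu_clipped_coef for relu6,
// vscale broadcasts act_param.Leaky_relu_alpha for leaky relu, and the
// plain RESULT macros are used when no activation is fused.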
void act_switch_3x3s1p1(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
const float *din_ptr4,
const float *din_ptr5,
float *doutr0,
float *doutr1,
float *doutr2,
float *doutr3,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask,
unsigned int *rmask,
float32x4_t vzero,
float *vbias,
int cnt,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
}
#else
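// armv7 variant: computes two output rows per call and passes the relu6
// threshold and leaky-relu slope through memory ([six_ptr]/[scale_ptr])
// rather than as vector register operands.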
void act_switch_3x3s1p1(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask_ptr,
unsigned int *rmask_ptr,
float32x4_t vzero,
float bias_val,
int cnt,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[six_ptr] "r"(vsix),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[scale_ptr] "r"(vscale),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
}
#endif
// clang-format on
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
@@ -1349,6 +2335,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
@@ -1486,106 +2473,25 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
}
int cnt = cnt_col;
if (flag_relu) {
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
act_switch_3x3s1p1(din_ptr0,
din_ptr1,
din_ptr2,
din_ptr3,
din_ptr4,
din_ptr5,
doutr0,
doutr1,
doutr2,
doutr3,
wr0,
wr1,
wr2,
vmask,
rmask,
vzero,
vbias,
cnt,
act_param);
dout_ptr = dout_ptr + 4 * w_out;
}
#else
@@ -1598,7 +2504,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
// unsigned int* rst_mask = rmask;
if (i == 0) {
din_ptr0 = zero_ptr;
@@ -1635,77 +2540,314 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
int cnt = cnt_col;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
if (flag_relu) {
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
act_switch_3x3s1p1(din_ptr0,
din_ptr1,
din_ptr2,
din_ptr3,
doutr0,
doutr1,
wr0,
wr1,
wr2,
vmask_ptr,
rmask_ptr,
vzero,
bias_val,
cnt,
act_param);
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
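// Small-width variant (w_in <= 4): one masked 4-lane block per row pair,
// dispatched on the fused activation type like act_switch_3x3s1p1 above.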
void act_switch_3x3s1p1_s(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp,
float32x4_t vzero,
float32x4_t wbias,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[six_ptr] "r"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[scale_ptr] "r"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
@@ -1722,6 +2864,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
@@ -1772,7 +2915,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
@@ -1782,127 +2924,19 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
default:
break;
}
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
} else {
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
act_switch_3x3s1p1_s(dr0,
dr1,
dr2,
dr3,
out_buf1,
out_buf2,
wr0,
wr1,
wr2,
vmask_rp,
vzero,
wbias,
act_param);
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
@@ -1916,6 +2950,490 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
} // end of processing batches
}
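// Scalar reference of the fused epilogues the assembly paths compute per
// lane (a sketch for clarity; act_ref is a hypothetical name, not called
// anywhere): kRelu clamps at zero, kRelu6 additionally clamps at
// Relu_clipped_coef, and kLeakyRelu scales negatives by Leaky_relu_alpha,
// matching the "din = din >= 0 ? din : din * scale" comments below.
static inline float act_ref(float din,
                            lite_api::ActivationType type,
                            float six,
                            float scale) {
  switch (type) {
    case lite_api::ActivationType::kRelu:
      return din > 0.f ? din : 0.f;  // max(din, 0)
    case lite_api::ActivationType::kRelu6:
      return din < 0.f ? 0.f : (din < six ? din : six);  // min(max(din, 0), six)
    case lite_api::ActivationType::kLeakyRelu:
      return din >= 0.f ? din : din * scale;  // scale negatives by alpha
    default:
      return din;  // no fused activation
  }
}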
#ifdef __aarch64__
void act_switch_3x3s1p0(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
const float *din_ptr4,
const float *din_ptr5,
float *doutr0,
float *doutr1,
float *doutr2,
float *doutr3,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask,
unsigned int *rmask,
float32x4_t vzero,
float *vbias,
int cnt,
int remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU6
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
}
#else
void act_switch_3x3s1p0(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask_ptr,
unsigned int *rmask_ptr,
float32x4_t vzero,
float bias_val,
int cnt,
int remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU6
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(vsix),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_LEAKY_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(
INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
}
#endif
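// Intrinsics sketch of the per-lane selects the kernels above emit
// (fmax/fmin for relu6; cmhs/fmul/bif for leaky relu). A minimal
// equivalent on 4-lane fp32 tiles, assuming <arm_neon.h> is already
// included (the float32x4_t uses above imply it); the *_ref names are
// hypothetical.
static inline float32x4_t relu6_ref(float32x4_t v, float32x4_t vsix) {
  // clamp to [0, six]
  return vminq_f32(vmaxq_f32(v, vdupq_n_f32(0.f)), vsix);
}
static inline float32x4_t leaky_relu_ref(float32x4_t v, float32x4_t vscale) {
  uint32x4_t ge = vcgeq_f32(v, vdupq_n_f32(0.f));  // mask of lanes >= 0
  return vbslq_f32(ge, v, vmulq_f32(v, vscale));   // v where ge, else v*alpha
}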
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
* width > 4
......@@ -1932,6 +3450,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! padding is done implicitly
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
......@@ -2060,15 +3579,16 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
}
int cnt = tile_w;
/*
if (flag_relu) {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0)
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0)
"ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234
"ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345
"ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0)
"ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0)
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
......@@ -2123,12 +3643,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0)
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0)
"ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234
"ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345
"ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0)
"ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0)
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
......@@ -2181,6 +3701,27 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
"v24",
"v25");
}
*/
act_switch_3x3s1p0(din_ptr0,
din_ptr1,
din_ptr2,
din_ptr3,
din_ptr4,
din_ptr5,
doutr0,
doutr1,
doutr2,
doutr3,
wr0,
wr1,
wr2,
vmask,
rmask,
vzero,
vbias,
cnt,
remain,
act_param);
dout_ptr = dout_ptr + 4 * w_out;
}
#else
......@@ -2219,6 +3760,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
/*
if (flag_relu) {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
......@@ -2301,13 +3843,328 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
"q13",
"q14",
"q15");
}
}*/
act_switch_3x3s1p0(din_ptr0,
din_ptr1,
din_ptr2,
din_ptr3,
doutr0,
doutr1,
wr0,
wr1,
wr2,
vmask_ptr,
rmask_ptr,
vzero,
bias_val,
cnt,
remain,
act_param);
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
void act_switch_3x3s1p0_s(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
float32x4_t vzero,
float32x4_t wbias,
unsigned int *vmask_ptr,
float bias_val,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[six_ptr] "r"(vsix),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
* width <= 4
......@@ -2324,6 +4181,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! padding is done implicitly
......@@ -2355,15 +4213,22 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
// #ifdef __aarch64__
// float32x4_t wbias;
// if (flag_bias) {
// wbias = vdupq_n_f32(bias[i]);
// } else {
// wbias = vdupq_n_f32(0.f);
// }
// #endif // __aarch64__
float32x4_t wbias;
float bias_val = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_val = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
......@@ -2396,135 +4261,154 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
break;
}
}
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
/*
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
*/
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
act_switch_3x3s1p0_s(dr0,
dr1,
dr2,
dr3,
out_buf1,
out_buf2,
wr0,
wr1,
wr2,
vmask_rp1,
vmask_rp2,
vzero,
wbias,
vmask_ptr,
bias_val,
act_param);
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
......
......@@ -25,6 +25,785 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
// clang-format off
#ifdef __aarch64__
#define COMPUTE \
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \
"ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ \
"ldp q4, q5, [%[inr0]]\n" /* load input r0*/ \
"ldp q10, q11, [%[inr1]]\n" /* load input r1*/ \
/* r0, r1, mul w0, get out r0, r1 */ \
"fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ \
"fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ \
"fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ \
"fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ \
"fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ \
"fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ \
"fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ \
"fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ \
/* r0, r1, mul w1, get out r0, r1 */ \
"fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ \
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ \
"fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ \
"fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ \
"fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ \
"fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ \
"fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ \
"fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ \
/* r0, r1, mul w2, get out r0, r1 */ \
"fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ \
"fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ \
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ \
"fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ \
"ldp q4, q5, [%[inr2]]\n" /* load input r2*/ \
"fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ \
"fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ \
"fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ \
"fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ \
/* r1, r2, mul w3, get out r0, r1 */ \
"fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ \
"fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ \
"fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ \
"fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ \
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ \
"fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ \
"fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ \
"fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ \
/* r1, r2, mul w4, get out r0, r1 */ \
"fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ \
"ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ \
"fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ \
"fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ \
"fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ \
"ldp x0, x1, [%[outl]] \n" \
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ \
"fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ \
"fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ \
"fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ \
/* r1, r2, mul w5, get out r0, r1 */ \
"fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ \
"fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ \
"ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ \
"fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ \
"fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ \
"ldp q10, q11, [%[inr3]]\n" /* load input r3*/ \
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ \
"fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ \
"fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ \
"fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ \
/* r2, r3, mul w6, get out r0, r1 */ \
"fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ \
"fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ \
"fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ \
"fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ \
"ldp x2, x3, [%[outl], #16] \n" \
"fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ \
"fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ \
"fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ \
"fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ \
/* r2, r3, mul w7, get out r0, r1 */ \
"fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ \
"fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ \
"fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ \
"fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ \
"ldp x4, x5, [%[outl], #32] \n" \
"fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ \
"fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ \
"fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ \
"fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ \
/* r2, r3, mul w8, get out r0, r1 */ \
"fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ \
"fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ \
"fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ \
"fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ \
"ldp x6, x7, [%[outl], #48] \n" \
"fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ \
"fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ \
"fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ \
"fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ \
\
"fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ \
"fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ \
"fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ \
"fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ \
"fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ \
"fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ \
"fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ \
"fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ \
/* transpose */ \
"trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ \
"trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ \
"trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ \
"trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ \
"trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ \
"trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ \
"trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ \
"trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ \
"trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \
"trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \
"trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \
"trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ \
"trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ \
"trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ \
"trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ \
"trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/
#define RELU \
"movi v0.4s, #0\n" /* for relu */ \
"ldr x0, [%[outl], #80]\n" \
"fmax v15.4s, v15.4s, v0.4s\n" \
"fmax v16.4s, v16.4s, v0.4s\n" \
"fmax v17.4s, v17.4s, v0.4s\n" \
"fmax v18.4s, v18.4s, v0.4s\n" \
"ld1 {v1.4s}, [x0]\n" \
"fmax v19.4s, v19.4s, v0.4s\n" \
"fmax v20.4s, v20.4s, v0.4s\n" \
"fmax v21.4s, v21.4s, v0.4s\n" \
"fmax v22.4s, v22.4s, v0.4s\n" \
"ldr x0, [%[outl]]\n" \
#define RELU6 \
"fmin v15.4s, v15.4s, v1.4s\n" \
"fmin v16.4s, v16.4s, v1.4s\n" \
"fmin v17.4s, v17.4s, v1.4s\n" \
"fmin v18.4s, v18.4s, v1.4s\n" \
"fmin v19.4s, v19.4s, v1.4s\n" \
"fmin v20.4s, v20.4s, v1.4s\n" \
"fmin v21.4s, v21.4s, v1.4s\n" \
"fmin v22.4s, v22.4s, v1.4s\n"
#define LEAKY_RELU \
"movi v0.4s, #0\n" /* for relu */ \
"ldr x0, [%[outl], #88]\n" \
"cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */ \
"ld1 {v9.4s}, [x0] \n" \
"cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v4.4s, v18.4s, v0.4s \n" /* vcgeq_u32 */ \
"ldr x0, [%[outl]] \n" \
"fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \
"fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \
"fmul v7.4s, v17.4s, v9.4s \n" /* mul */ \
"fmul v8.4s, v18.4s, v9.4s \n" /* mul */ \
"bif v15.16b, v5.16b, v1.16b \n" /* choose*/ \
"bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \
"bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \
"bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \
"fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \
"fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \
"fmul v8.4s, v22.4s, v9.4s \n" /* mul */ \
"bif v19.16b, v5.16b, v1.16b \n" /* choose*/ \
"bif v20.16b, v6.16b, v2.16b \n" /* choose*/ \
"bif v21.16b, v7.16b, v3.16b \n" /* choose*/ \
"bif v22.16b, v8.16b, v4.16b \n" /* choose*/
#define STORE \
"cbnz %w[flag_mask], 1f\n" \
"str q15, [x0]\n" /* save outc00 */ \
"str q16, [x4]\n" /* save outc01 */ \
"str q17, [x1]\n" /* save outc10 */ \
"str q18, [x5]\n" /* save outc11 */ \
"str q19, [x2]\n" /* save outc20 */ \
"str q20, [x6]\n" /* save outc21 */ \
"str q21, [x3]\n" /* save outc30 */ \
"str q22, [x7]\n" /* save outc31 */ \
"b 2f\n" \
"1:\n" \
"str q15, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q17, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q19, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q21, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q16, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q18, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q20, [%[out]], #16 \n" /* save remain to pre_out */ \
"str q22, [%[out]], #16 \n" /* save remain to pre_out */ \
"2:\n"
#else
#define COMPUTE \
/* load weights */ \
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" \
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" \
/* load r0, r1 */ \
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" \
"vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n" \
/* main loop */ \
"0: @ main loop\n" \
/* mul r0 with w0, w1, w2, get out r0 */ \
"vmul.f32 q8, q5, q0 @ w0 * inr00\n" \
"vmul.f32 q9, q5, q1 @ w0 * inr01\n" \
"vmul.f32 q10, q5, q2 @ w0 * inr02\n" \
"vmul.f32 q11, q5, q3 @ w0 * inr03\n" \
"vmla.f32 q8, q6, q1 @ w1 * inr01\n" \
"vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" \
"vmla.f32 q9, q6, q2 @ w1 * inr02\n" \
"vmla.f32 q10, q6, q3 @ w1 * inr03\n" \
"vmla.f32 q11, q6, q0 @ w1 * inr04\n" \
"vmla.f32 q8, q7, q2 @ w2 * inr02\n" \
"vmla.f32 q9, q7, q3 @ w2 * inr03\n" \
"vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" \
"vmla.f32 q10, q7, q0 @ w2 * inr04\n" \
"vmla.f32 q11, q7, q1 @ w2 * inr05\n" \
"vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" \
"vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" \
/* mul r1 with w0-w5, get out r0, r1 */ \
"vmul.f32 q12, q5, q2 @ w0 * inr10\n" \
"vmul.f32 q13, q5, q3 @ w0 * inr11\n" \
"vmul.f32 q14, q5, q0 @ w0 * inr12\n" \
"vmul.f32 q15, q5, q1 @ w0 * inr13\n" \
"vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n" \
"vmla.f32 q8, q4, q2 @ w3 * inr10\n" \
"vmla.f32 q9, q4, q3 @ w3 * inr11\n" \
"vmla.f32 q10, q4, q0 @ w3 * inr12\n" \
"vmla.f32 q11, q4, q1 @ w3 * inr13\n" \
/* mul r1 with w1, w4, get out r1, r0 */ \
"vmla.f32 q8, q5, q3 @ w4 * inr11\n" \
"vmla.f32 q12, q6, q3 @ w1 * inr11\n" \
"vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" \
"vmla.f32 q9, q5, q0 @ w4 * inr12\n" \
"vmla.f32 q13, q6, q0 @ w1 * inr12\n" \
"vmla.f32 q10, q5, q1 @ w4 * inr13\n" \
"vmla.f32 q14, q6, q1 @ w1 * inr13\n" \
"vmla.f32 q11, q5, q2 @ w4 * inr14\n" \
"vmla.f32 q15, q6, q2 @ w1 * inr14\n" \
"vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" \
/* mul r1 with w2, w5, get out r1, r0 */ \
"vmla.f32 q12, q7, q0 @ w2 * inr12\n" \
"vmla.f32 q13, q7, q1 @ w2 * inr13\n" \
"vmla.f32 q8, q6, q0 @ w5 * inr12\n" \
"vmla.f32 q9, q6, q1 @ w5 * inr13\n" \
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" \
"vmla.f32 q14, q7, q2 @ w2 * inr14\n" \
"vmla.f32 q15, q7, q3 @ w2 * inr15\n" \
"vmla.f32 q10, q6, q2 @ w5 * inr14\n" \
"vmla.f32 q11, q6, q3 @ w5 * inr15\n" \
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" \
"vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" \
/* mul r2 with w3-w8, get out r0, r1 */ \
"vmla.f32 q12, q4, q0 @ w3 * inr20\n" \
"vmla.f32 q13, q4, q1 @ w3 * inr21\n" \
"vmla.f32 q14, q4, q2 @ w3 * inr22\n" \
"vmla.f32 q15, q4, q3 @ w3 * inr23\n" \
"vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" \
"vmla.f32 q8, q7, q0 @ w6 * inr20\n" \
"vmla.f32 q9, q7, q1 @ w6 * inr21\n" \
"vmla.f32 q10, q7, q2 @ w6 * inr22\n" \
"vmla.f32 q11, q7, q3 @ w6 * inr23\n" \
/* mul r2 with w4, w7, get out r1, r0 */ \
"vmla.f32 q8, q4, q1 @ w7 * inr21\n" \
"vmla.f32 q12, q5, q1 @ w4 * inr21\n" \
"vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" \
"vmla.f32 q9, q4, q2 @ w7 * inr22\n" \
"vmla.f32 q13, q5, q2 @ w4 * inr22\n" \
"vmla.f32 q10, q4, q3 @ w7 * inr23\n" \
"vmla.f32 q14, q5, q3 @ w4 * inr23\n" \
"vmla.f32 q11, q4, q0 @ w7 * inr24\n" \
"vmla.f32 q15, q5, q0 @ w4 * inr24\n" \
"vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" \
/* mul r1 with w5, w8, get out r1, r0 */ \
"vmla.f32 q12, q6, q2 @ w5 * inr22\n" \
"vmla.f32 q13, q6, q3 @ w5 * inr23\n" \
"vmla.f32 q8, q5, q2 @ w8 * inr22\n" \
"vmla.f32 q9, q5, q3 @ w8 * inr23\n" \
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" \
"ldr r4, [%[outl], #32] @ load bias addr to r4\n" \
"vmla.f32 q14, q6, q0 @ w5 * inr24\n" \
"vmla.f32 q15, q6, q1 @ w5 * inr25\n" \
"vmla.f32 q10, q5, q0 @ w8 * inr24\n" \
"vmla.f32 q11, q5, q1 @ w8 * inr25\n" \
"vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n" \
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \
/* mul r3 with w6, w7, w8, get out r1 */ \
"vmla.f32 q12, q7, q2 @ w6 * inr30\n" \
"vmla.f32 q13, q7, q3 @ w6 * inr31\n" \
"vmla.f32 q14, q7, q0 @ w6 * inr32\n" \
"vmla.f32 q15, q7, q1 @ w6 * inr33\n" \
"vmla.f32 q12, q4, q3 @ w7 * inr31\n" \
"vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" \
"vld1.32 {d12-d13}, [r4] @ load bias\n" \
"vmla.f32 q13, q4, q0 @ w7 * inr32\n" \
"vmla.f32 q14, q4, q1 @ w7 * inr33\n" \
"vmla.f32 q15, q4, q2 @ w7 * inr34\n" \
"ldr r0, [%[outl]] @ load outc00 to r0\n" \
"vmla.f32 q12, q5, q0 @ w8 * inr32\n" \
"vmla.f32 q13, q5, q1 @ w8 * inr33\n" \
"ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \
"vmla.f32 q14, q5, q2 @ w8 * inr34\n" \
"vmla.f32 q15, q5, q3 @ w8 * inr35\n" \
"ldr r1, [%[outl], #4] @ load outc10 to r1\n" \
"vadd.f32 q8, q8, q6 @ r00 add bias\n" \
"vadd.f32 q9, q9, q6 @ r01 add bias\n" \
"vadd.f32 q10, q10, q6 @ r02 add bias\n" \
"vadd.f32 q11, q11, q6 @ r03 add bias\n" \
"ldr r2, [%[outl], #8] @ load outc20 to r2\n" \
"vadd.f32 q12, q12, q6 @ r10 add bias\n" \
"vadd.f32 q13, q13, q6 @ r11 add bias\n" \
"vadd.f32 q14, q14, q6 @ r12 add bias\n" \
"vadd.f32 q15, q15, q6 @ r13 add bias\n" \
"ldr r3, [%[outl], #12] @ load outc30 to r3\n" \
"vmov.u32 q7, #0 @ mov zero to q7\n"
#define RELU \
"vmax.f32 q8, q8, q7 @ r00 relu\n" \
"vmax.f32 q9, q9, q7 @ r01 relu\n" \
"vmax.f32 q10, q10, q7 @ r02 relu\n" \
"vmax.f32 q11, q11, q7 @ r03 relu\n" \
"vmax.f32 q12, q12, q7 @ r10 relu\n" \
"vmax.f32 q13, q13, q7 @ r11 relu\n" \
"vmax.f32 q14, q14, q7 @ r12 relu\n" \
"vmax.f32 q15, q15, q7 @ r13 relu\n"
#define RELU6 \
"ldr r4, [%[outl], #40] @ load six to r4\n" \
"vld1.32 {d12-d13}, [r4] @load data \n" \
"vmin.f32 q8, q8, q6 @ r00 relu\n" \
"vmin.f32 q9, q9, q6 @ r01 relu\n" \
"vmin.f32 q10, q10, q6 @ r02 relu\n" \
"vmin.f32 q11, q11, q6 @ r03 relu\n" \
"vmin.f32 q12, q12, q6 @ r10 relu\n" \
"vmin.f32 q13, q13, q6 @ r11 relu\n" \
"vmin.f32 q14, q14, q6 @ r12 relu\n" \
"vmin.f32 q15, q15, q6 @ r13 relu\n"
#define LEAKY_RELU \
"ldr r4, [%[outl], #44] @ load scale to r4\n" \
"vld1.32 {d12-d13}, [r4] @load data \n" \
"vcge.f32 q0, q8, q7 @ q0 > 0 \n" \
"vcge.f32 q1, q9, q7 @ q0 > 0 \n" \
"vmul.f32 q4, q8, q6 \n" \
"vmul.f32 q5, q9, q6 \n" \
"vcge.f32 q2, q10, q7 @ q0 > 0 \n" \
"vcge.f32 q3, q11, q7 @ q0 > 0 \n" \
"vbif q8, q4, q0 @ choose \n" \
"vbif q9, q5, q1 @ choose \n" \
"vmul.f32 q4, q10, q6 \n" \
"vmul.f32 q5, q11, q6 \n" \
"vbif q10, q4, q2 @ choose \n" \
"vbif q11, q5, q3 @ choose \n" \
"vcge.f32 q0, q12, q7 @ q0 > 0 \n" \
"vcge.f32 q1, q13, q7 @ q0 > 0 \n" \
"vmul.f32 q4, q12, q6 \n" \
"vmul.f32 q5, q13, q6 \n" \
"vcge.f32 q2, q14, q7 @ q0 > 0 \n" \
"vcge.f32 q3, q15, q7 @ q0 > 0 \n" \
"vbif q12, q4, q0 @ choose \n" \
"vbif q13, q5, q1 @ choose \n" \
"vmul.f32 q4, q14, q6 \n" \
"vmul.f32 q5, q15, q6 \n" \
"vbif q14, q4, q2 @ choose \n" \
"vbif q15, q5, q3 @ choose \n"
#define STORE \
"ldr r4, [%[outl], #16] @ load outc01 to r4\n" \
"vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" \
"vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \
"vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \
"ldr r5, [%[outl], #20] @ load outc11 to r5\n" \
"vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \
"vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" \
"vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \
"vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" \
"cmp %[flag_mask], #0 @ cmp flag mask\n" \
"bne 2f\n" \
"vst1.32 {d16-d17}, [r0] @ save outc00\n" \
"vst1.32 {d18-d19}, [r1] @ save outc10\n" \
"vst1.32 {d20-d21}, [r2] @ save outc20\n" \
"vst1.32 {d22-d23}, [r3] @ save outc30\n" \
"vst1.32 {d24-d25}, [r4] @ save outc01\n" \
"vst1.32 {d26-d27}, [r5] @ save outc11\n" \
"ldr r0, [%[outl], #24] @ load outc21 to r0\n" \
"ldr r1, [%[outl], #28] @ load outc31 to r1\n" \
"vst1.32 {d28-d29}, [r0] @ save outc21\n" \
"vst1.32 {d30-d31}, [r1] @ save outc31\n" \
"b 3f @ branch end\n" \
"2: \n" \
"vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" \
"vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" \
"3: \n"
#endif
// clang-format on
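// Note on the outl[] pointer table consumed by the macros above (built in
// conv_3x3s1_depthwise_fp32 below): entries 0-7 are the eight output
// row/channel pointers, entry 8 is bias_local, entry 9 relu_ptr, entry 10
// six_ptr, entry 11 scale_ptr. RELU6/LEAKY_RELU therefore load six/scale
// at byte offsets #80/#88 on aarch64 (8-byte pointers) and #40/#44 on
// armv7 (4-byte pointers).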
void act_switch_3x3s1(const float* inr0,
const float* inr1,
const float* inr2,
const float* inr3,
float* out0,
const float* weight_c,
bool flag_mask,
void* outl_ptr,
float32x4_t w0,
float32x4_t w1,
float32x4_t w2,
float32x4_t w3,
float32x4_t w4,
float32x4_t w5,
float32x4_t w6,
float32x4_t w7,
float32x4_t w8,
float32x4_t vbias,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(COMPUTE LEAKY_RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE LEAKY_RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[inr3] "+r"(inr3),
[out] "+r"(out0)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[vbias] "w"(vbias),
[outl] "r"(outl_ptr),
[flag_mask] "r"(flag_mask)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"x0",
"x1",
"x2",
"x3",
"x4",
"x5",
"x6",
"x7");
#else
asm volatile(COMPUTE STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[out0] "+r"(out0),
[wc0] "+r"(weight_c)
: [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15",
"r0",
"r1",
"r2",
"r3",
"r4",
"r5");
#endif
}
}
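// Design note: the aarch64 variant of act_switch_3x3s1 keeps all nine
// 4-lane weight vectors resident in "w"-constrained registers, while the
// armv7 variant (with fewer q registers) streams them from weight_c inside
// COMPUTE and rewinds the pointer with "sub %[wc0], %[wc0], #144"
// (9 weights x 4 floats x 4 bytes) every tile.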
void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data,
int bs,
......@@ -37,6 +816,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int threads = ctx->threads();
......@@ -78,6 +858,31 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
float six_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f};
float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
break;
case lite_api::ActivationType::kRelu6:
six_ptr[0] = act_param.Relu_clipped_coef;
six_ptr[1] = act_param.Relu_clipped_coef;
six_ptr[2] = act_param.Relu_clipped_coef;
six_ptr[3] = act_param.Relu_clipped_coef;
break;
case lite_api::ActivationType::kLeakyRelu:
scale_ptr[0] = act_param.Leaky_relu_alpha;
scale_ptr[1] = act_param.Leaky_relu_alpha;
scale_ptr[2] = act_param.Leaky_relu_alpha;
scale_ptr[3] = act_param.Leaky_relu_alpha;
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
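// six_ptr/scale_ptr are broadcast to 4-lane arrays so the kernels can
// fetch them with a single vector load through the outl[] table below;
// relu_ptr stays zero because the kRelu paths materialize the zero with
// movi/vmov instead.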
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
......@@ -147,6 +952,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outc21 = ptr_write;
outc31 = ptr_write;
}
float* outl[] = {outc00,
outc10,
outc20,
......@@ -156,361 +962,54 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outc21,
outc31,
reinterpret_cast<float*>(bias_local),
reinterpret_cast<float*>(flag_relu)};
reinterpret_cast<float*>(relu_ptr),
reinterpret_cast<float*>(six_ptr),
reinterpret_cast<float*>(scale_ptr)};
void* outl_ptr = reinterpret_cast<void*>(outl);
for (int w = 0; w < w_loop; ++w) {
bool flag_mask = (w == w_loop - 1) && flag_remain;
float* out0 = pre_out;
// clang-format off
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/
"ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/
"ldp q4, q5, [%[inr0]]\n" /* load input r0*/
"ldp q10, q11, [%[inr1]]\n" /* load input r1*/
/* r0, r1, mul w0, get out r0, r1 */
"fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/
"fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/
"fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/
"fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/
"fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/
"fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/
"fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/
"fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/
/* r0, r1, mul w1, get out r0, r1 */
"fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/
"fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/
"fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/
"fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/
"fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/
"fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/
"fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/
"fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/
/* r0, r1, mul w2, get out r0, r1 */
"fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/
"fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/
"fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/
"fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/
"ldp q4, q5, [%[inr2]]\n" /* load input r2*/
"fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/
"fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/
"fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/
"fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/
/* r1, r2, mul w3, get out r0, r1 */
"fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/
"fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/
"fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/
"fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/
"fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/
"fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/
"fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/
/* r1, r2, mul w4, get out r0, r1 */
"fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/
"ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/
"fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/
"fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/
"fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/
"ldp x0, x1, [%[outl]] \n"
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/
"fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/
"fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/
"fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/
/* r1, r2, mul w5, get out r0, r1 */
"fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/
"fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/
"ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/
"fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/
"fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/
"ldp q10, q11, [%[inr3]]\n" /* load input r3*/
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/
"fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/
"fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/
"fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/
/* r2, r3, mul w6, get out r0, r1 */
"fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/
"fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/
"fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/
"fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/
"ldp x2, x3, [%[outl], #16] \n"
"fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/
"fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/
"fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/
"fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/
/* r2, r3, mul w7, get out r0, r1 */
"fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/
"fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/
"fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/
"fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/
"ldp x4, x5, [%[outl], #32] \n"
"fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/
"fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/
"fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/
"fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/
/* r2, r3, mul w8, get out r0, r1 */
"fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/
"fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/
"fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/
"fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/
"ldp x6, x7, [%[outl], #48] \n"
"fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/
"fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/
"fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/
"fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/
"fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */
"fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */
"fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */
"fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */
"fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */
"fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */
"fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */
"fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */
/* transpose */
"trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/
"trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/
"trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/
"trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/
"trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/
"trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/
"trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/
"trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/
"trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
"trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
"trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
"trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
"trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/
"trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/
"trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/
"trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/
"cbz %w[flag_relu], 0f\n" /* skip relu*/
"movi v0.4s, #0\n" /* for relu */
"fmax v15.4s, v15.4s, v0.4s\n"
"fmax v16.4s, v16.4s, v0.4s\n"
"fmax v17.4s, v17.4s, v0.4s\n"
"fmax v18.4s, v18.4s, v0.4s\n"
"fmax v19.4s, v19.4s, v0.4s\n"
"fmax v20.4s, v20.4s, v0.4s\n"
"fmax v21.4s, v21.4s, v0.4s\n"
"fmax v22.4s, v22.4s, v0.4s\n"
"0:\n"
"cbnz %w[flag_mask], 1f\n"
"str q15, [x0]\n" /* save outc00 */
"str q16, [x4]\n" /* save outc01 */
"str q17, [x1]\n" /* save outc10 */
"str q18, [x5]\n" /* save outc11 */
"str q19, [x2]\n" /* save outc20 */
"str q20, [x6]\n" /* save outc21 */
"str q21, [x3]\n" /* save outc30 */
"str q22, [x7]\n" /* save outc31 */
"b 2f\n"
"1:\n"
"str q15, [%[out]], #16 \n" /* save remain to pre_out */
"str q17, [%[out]], #16 \n" /* save remain to pre_out */
"str q19, [%[out]], #16 \n" /* save remain to pre_out */
"str q21, [%[out]], #16 \n" /* save remain to pre_out */
"str q16, [%[out]], #16 \n" /* save remain to pre_out */
"str q18, [%[out]], #16 \n" /* save remain to pre_out */
"str q20, [%[out]], #16 \n" /* save remain to pre_out */
"str q22, [%[out]], #16 \n" /* save remain to pre_out */
"2:\n"
:[inr0] "+r"(inr0), [inr1] "+r"(inr1),
[inr2] "+r"(inr2), [inr3] "+r"(inr3),
[out]"+r"(out0)
:[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
[w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
[w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
[vbias]"w" (vbias), [outl] "r" (outl_ptr),
[flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8", "v9", "v10", "v11", "v15",
"v16","v17","v18","v19","v20","v21","v22",
"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7"
);
act_switch_3x3s1(inr0,
inr1,
inr2,
inr3,
out0,
weight_c,
flag_mask,
outl_ptr,
w0,
w1,
w2,
w3,
w4,
w5,
w6,
w7,
w8,
vbias,
act_param);
#else
asm volatile(
/* load weights */
"vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n"
/* load r0, r1 */
"vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n"
"vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n"
/* main loop */
"0: @ main loop\n"
/* mul r0 with w0, w1, w2, get out r0 */
"vmul.f32 q8, q5, q0 @ w0 * inr00\n"
"vmul.f32 q9, q5, q1 @ w0 * inr01\n"
"vmul.f32 q10, q5, q2 @ w0 * inr02\n"
"vmul.f32 q11, q5, q3 @ w0 * inr03\n"
"vmla.f32 q8, q6, q1 @ w1 * inr01\n"
"vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n"
"vmla.f32 q9, q6, q2 @ w1 * inr02\n"
"vmla.f32 q10, q6, q3 @ w1 * inr03\n"
"vmla.f32 q11, q6, q0 @ w1 * inr04\n"
"vmla.f32 q8, q7, q2 @ w2 * inr02\n"
"vmla.f32 q9, q7, q3 @ w2 * inr03\n"
"vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n"
"vmla.f32 q10, q7, q0 @ w2 * inr04\n"
"vmla.f32 q11, q7, q1 @ w2 * inr05\n"
"vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n"
"vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n"
/* mul r1 with w0-w5, get out r0, r1 */
"vmul.f32 q12, q5, q2 @ w0 * inr10\n"
"vmul.f32 q13, q5, q3 @ w0 * inr11\n"
"vmul.f32 q14, q5, q0 @ w0 * inr12\n"
"vmul.f32 q15, q5, q1 @ w0 * inr13\n"
"vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n"
"vmla.f32 q8, q4, q2 @ w3 * inr10\n"
"vmla.f32 q9, q4, q3 @ w3 * inr11\n"
"vmla.f32 q10, q4, q0 @ w3 * inr12\n"
"vmla.f32 q11, q4, q1 @ w3 * inr13\n"
/* mul r1 with w1, w4, get out r1, r0 */
"vmla.f32 q8, q5, q3 @ w4 * inr11\n"
"vmla.f32 q12, q6, q3 @ w1 * inr11\n"
"vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n"
"vmla.f32 q9, q5, q0 @ w4 * inr12\n"
"vmla.f32 q13, q6, q0 @ w1 * inr12\n"
"vmla.f32 q10, q5, q1 @ w4 * inr13\n"
"vmla.f32 q14, q6, q1 @ w1 * inr13\n"
"vmla.f32 q11, q5, q2 @ w4 * inr14\n"
"vmla.f32 q15, q6, q2 @ w1 * inr14\n"
"vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n"
/* mul r1 with w2, w5, get out r1, r0 */
"vmla.f32 q12, q7, q0 @ w2 * inr12\n"
"vmla.f32 q13, q7, q1 @ w2 * inr13\n"
"vmla.f32 q8, q6, q0 @ w5 * inr12\n"
"vmla.f32 q9, q6, q1 @ w5 * inr13\n"
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n"
"vmla.f32 q14, q7, q2 @ w2 * inr14\n"
"vmla.f32 q15, q7, q3 @ w2 * inr15\n"
"vmla.f32 q10, q6, q2 @ w5 * inr14\n"
"vmla.f32 q11, q6, q3 @ w5 * inr15\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n"
"vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n"
/* mul r2 with w3-w8, get out r0, r1 */
"vmla.f32 q12, q4, q0 @ w3 * inr20\n"
"vmla.f32 q13, q4, q1 @ w3 * inr21\n"
"vmla.f32 q14, q4, q2 @ w3 * inr22\n"
"vmla.f32 q15, q4, q3 @ w3 * inr23\n"
"vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n"
"vmla.f32 q8, q7, q0 @ w6 * inr20\n"
"vmla.f32 q9, q7, q1 @ w6 * inr21\n"
"vmla.f32 q10, q7, q2 @ w6 * inr22\n"
"vmla.f32 q11, q7, q3 @ w6 * inr23\n"
/* mul r2 with w4, w7, get out r1, r0 */
"vmla.f32 q8, q4, q1 @ w7 * inr21\n"
"vmla.f32 q12, q5, q1 @ w4 * inr21\n"
"vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n"
"vmla.f32 q9, q4, q2 @ w7 * inr22\n"
"vmla.f32 q13, q5, q2 @ w4 * inr22\n"
"vmla.f32 q10, q4, q3 @ w7 * inr23\n"
"vmla.f32 q14, q5, q3 @ w4 * inr23\n"
"vmla.f32 q11, q4, q0 @ w7 * inr24\n"
"vmla.f32 q15, q5, q0 @ w4 * inr24\n"
"vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n"
/* mul r1 with w5, w8, get out r1, r0 */
"vmla.f32 q12, q6, q2 @ w5 * inr22\n"
"vmla.f32 q13, q6, q3 @ w5 * inr23\n"
"vmla.f32 q8, q5, q2 @ w8 * inr22\n"
"vmla.f32 q9, q5, q3 @ w8 * inr23\n"
"vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n"
"ldr r4, [%[outl], #32] @ load bias addr to r4\n"
"vmla.f32 q14, q6, q0 @ w5 * inr24\n"
"vmla.f32 q15, q6, q1 @ w5 * inr25\n"
"vmla.f32 q10, q5, q0 @ w8 * inr24\n"
"vmla.f32 q11, q5, q1 @ w8 * inr25\n"
"vld1.32 {d0-d3}, [%[r3]]! @ load r3, q0, q1\n"
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n"
/* mul r3 with w6, w7, w8, get out r1 */
"vmla.f32 q12, q7, q2 @ w6 * inr30\n"
"vmla.f32 q13, q7, q3 @ w6 * inr31\n"
"vmla.f32 q14, q7, q0 @ w6 * inr32\n"
"vmla.f32 q15, q7, q1 @ w6 * inr33\n"
"vmla.f32 q12, q4, q3 @ w7 * inr31\n"
"vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n"
"vld1.32 {d12-d13}, [r4] @ load bias\n"
"vmla.f32 q13, q4, q0 @ w7 * inr32\n"
"vmla.f32 q14, q4, q1 @ w7 * inr33\n"
"vmla.f32 q15, q4, q2 @ w7 * inr34\n"
"ldr r0, [%[outl]] @ load outc00 to r0\n"
"vmla.f32 q12, q5, q0 @ w8 * inr32\n"
"vmla.f32 q13, q5, q1 @ w8 * inr33\n"
"ldr r5, [%[outl], #36] @ load flag_relu to r5\n"
"vmla.f32 q14, q5, q2 @ w8 * inr34\n"
"vmla.f32 q15, q5, q3 @ w8 * inr35\n"
"ldr r1, [%[outl], #4] @ load outc10 to r1\n"
"vadd.f32 q8, q8, q6 @ r00 add bias\n"
"vadd.f32 q9, q9, q6 @ r01 add bias\n"
"vadd.f32 q10, q10, q6 @ r02 add bias\n"
"vadd.f32 q11, q11, q6 @ r03 add bias\n"
"ldr r2, [%[outl], #8] @ load outc20 to r2\n"
"vadd.f32 q12, q12, q6 @ r10 add bias\n"
"vadd.f32 q13, q13, q6 @ r11 add bias\n"
"vadd.f32 q14, q14, q6 @ r12 add bias\n"
"vadd.f32 q15, q15, q6 @ r13 add bias\n"
"ldr r3, [%[outl], #12] @ load outc30 to r3\n"
"vmov.u32 q7, #0 @ mov zero to q7\n"
"cmp r5, #0 @ cmp flag relu\n"
"beq 1f @ skip relu\n"
"vmax.f32 q8, q8, q7 @ r00 relu\n"
"vmax.f32 q9, q9, q7 @ r01 relu\n"
"vmax.f32 q10, q10, q7 @ r02 relu\n"
"vmax.f32 q11, q11, q7 @ r03 relu\n"
"vmax.f32 q12, q12, q7 @ r10 relu\n"
"vmax.f32 q13, q13, q7 @ r11 relu\n"
"vmax.f32 q14, q14, q7 @ r12 relu\n"
"vmax.f32 q15, q15, q7 @ r13 relu\n"
"1:\n"
"ldr r4, [%[outl], #16] @ load outc01 to r4\n"
"vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n"
"vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n"
"vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n"
"vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n"
"ldr r5, [%[outl], #20] @ load outc11 to r5\n"
"vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n"
"vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n"
"vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n"
"vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n"
"cmp %[flag_mask], #0 @ cmp flag mask\n"
"bne 2f\n"
"vst1.32 {d16-d17}, [r0] @ save outc00\n"
"vst1.32 {d18-d19}, [r1] @ save outc10\n"
"vst1.32 {d20-d21}, [r2] @ save outc20\n"
"vst1.32 {d22-d23}, [r3] @ save outc30\n"
"vst1.32 {d24-d25}, [r4] @ save outc01\n"
"vst1.32 {d26-d27}, [r5] @ save outc11\n"
"ldr r0, [%[outl], #24] @ load outc21 to r0\n"
"ldr r1, [%[outl], #28] @ load outc31 to r1\n"
"vst1.32 {d28-d29}, [r0] @ save outc21\n"
"vst1.32 {d30-d31}, [r1] @ save outc31\n"
"b 3f @ branch end\n"
"2: \n"
"vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n"
"vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n"
"3: \n"
: [r0] "+r"(inr0), [r1] "+r"(inr1),
[r2] "+r"(inr2), [r3] "+r"(inr3),
[out0] "+r"(out0), [wc0] "+r"(weight_c)
: [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr)
: "cc", "memory",
"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5"
);
#endif  // __aarch64__
// clang-format on
act_switch_3x3s1(inr0,
inr1,
inr2,
inr3,
out0,
weight_c,
flag_mask,
outl_ptr,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
vbias,
act_param);
#endif
outl[0] += 4;
outl[1] += 4;
outl[2] += 4;
......@@ -519,6 +1018,10 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
outl[5] += 4;
outl[6] += 4;
outl[7] += 4;
inr0 += 16;
inr1 += 16;
inr2 += 16;
inr3 += 16;
if (flag_mask) {
memcpy(outl[0] - 4, pre_out, remain * sizeof(float));
memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float));
......
......@@ -75,6 +75,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
//! prepack input to tmp buffer
//! write output to tmp buffer
auto paddings = *param.paddings;
auto act_param = param.activation_param;
const int threads = ctx->threads();
int l2_size = ctx->llc_size() / sizeof(float);
const int pad_w = paddings[2];
......@@ -510,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
#pragma omp parallel for num_threads(threads)
......@@ -839,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
}
}
......
......@@ -205,14 +205,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n"
#define LEFT_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
......@@ -244,53 +242,52 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"blt 1f \n"
#define MID_COMPUTE_S2 \
"2: \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[0] \n" \
"fmul v12.4s, v1.4s, %[w0].s[1] \n" \
"fmla v16.4s, v10.4s, %[w0].s[2] \n" \
\
"ext v10.16b, v2.16b, v18.16b, #4 \n" \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[0] \n" \
"fmla v12.4s, v3.4s, %[w1].s[1] \n" \
"fmla v16.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v4.16b, v19.16b, #4 \n" \
\
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[0] \n" \
"fmla v11.4s, v4.4s, %[w2].s[0] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[1] \n" \
"fmla v12.4s, v5.4s, %[w2].s[1] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[2] \n" \
"fmla v16.4s, v10.4s, %[w2].s[2] \n" \
\
"ext v10.16b, v6.16b, v20.16b, #4 \n" \
\
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[0] \n" \
"fmla v14.4s, v7.4s, %[w1].s[1] \n" \
"fmla v17.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v8.16b, v21.16b, #4 \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n"
#define MID_COMPUTE_S2 \
"2: \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[0] \n" \
"fmul v12.4s, v1.4s, %[w0].s[1] \n" \
"fmla v16.4s, v10.4s, %[w0].s[2] \n" \
\
"ext v10.16b, v2.16b, v18.16b, #4 \n" \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[0] \n" \
"fmla v12.4s, v3.4s, %[w1].s[1] \n" \
"fmla v16.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v4.16b, v19.16b, #4 \n" \
\
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[0] \n" \
"fmla v11.4s, v4.4s, %[w2].s[0] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[1] \n" \
"fmla v12.4s, v5.4s, %[w2].s[1] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[2] \n" \
"fmla v16.4s, v10.4s, %[w2].s[2] \n" \
\
"ext v10.16b, v6.16b, v20.16b, #4 \n" \
\
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[0] \n" \
"fmla v14.4s, v7.4s, %[w1].s[1] \n" \
"fmla v17.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v8.16b, v21.16b, #4 \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n"
#define MID_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......@@ -360,14 +357,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n" \
"ld1 {v1.4s}, [%[outptr1]] \n"
"ld1 {v1.4s}, [%[outptr1]] \n" /* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n"
#define RIGHT_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......@@ -382,11 +377,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"4: \n"
#define LEFT_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n" \
\
"fmax v16.4s, v16.4s, %[vzero].4s \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
......@@ -424,14 +414,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"blt 1f \n"
#define MID_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n" \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......@@ -457,11 +439,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......
......@@ -20,6 +20,7 @@
#include "lite/backends/arm/math/sgemm.h"
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/target_wrapper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
......@@ -28,6 +29,7 @@ namespace arm {
namespace math {
#define LITEMAX(a, b) ((a) > (b) ? (a) : (b))
#define LITEMIN(a, b) ((a) < (b) ? (a) : (b))
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
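// Illustrative sanity checks: ROUNDUP(a, b) rounds a up to the next
// multiple of b, so tile and channel counts can be padded to vector widths.
static_assert(ROUNDUP(10, 8) == 16, "10 rounds up to the next multiple of 8");
static_assert(ROUNDUP(16, 8) == 16, "exact multiples are unchanged");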
template <PrecisionType Ptype>
......@@ -589,7 +591,238 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
}
}
}
// clang-format off
#ifdef __aarch64__
#define NCHWC1_TRANS_FP32_COMPUTE \
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"movi v20.4s, #0 \n" /* for relu */ \
"1: \n" /* main loop*/
#define NCHWC1_TRANS_FP32_RELU \
"fmax v0.4s, v0.4s, v20.4s \n" /*relu*/ \
"fmax v1.4s, v1.4s, v20.4s \n" /*relu*/ \
"fmax v2.4s, v2.4s, v20.4s \n" /*relu*/ \
"fmax v3.4s, v3.4s, v20.4s \n" /*relu*/
#define NCHWC1_TRANS_FP32_RELU6 \
"fmin v0.4s, v0.4s, %[six].4s \n" /* relu6 */ \
"fmin v1.4s, v1.4s, %[six].4s \n" /* relu6 */ \
"fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \
"fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC1_TRANS_FP32_LEAKY_RELU \
"cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \
"fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \
"fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \
"fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */ \
"bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \
"bif v2.16b, v10.16b, v6.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v7.16b \n" /* choose*/
#define NCHWC1_TRANS_FP32_STORE \
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \
\
"str q0, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
"str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
"str q3, [%[doutc0r0]], #16 \n" /* store c2r0*/ \
"ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
"ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \
\
"bne 1b \n" /* jump to main loop*/
#else
#define NCHWC1_TRANS_FP32_COMPUTE \
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0 \n" \
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data, c0r0 \n" \
"vmov.u32 q15, #0 @ dump zero\n" \
"1: @ main loop\n"
#define NCHWC1_TRANS_FP32_RELU \
"vmax.f32 q0, q0, q15 @ relu\n" \
"vmax.f32 q1, q1, q15 @ relu\n" \
"vmax.f32 q2, q2, q15 @ relu\n" \
"vmax.f32 q3, q3, q15 @ relu\n"
#define NCHWC1_TRANS_FP32_RELU6 \
"vmin.f32 q0, q0, %q[six] @ relu6 \n" \
"vmin.f32 q1, q1, %q[six] @ relu6 \n" \
"vmin.f32 q2, q2, %q[six] @ relu6 \n" \
"vmin.f32 q3, q3, %q[six] @ relu6 \n"
#define NCHWC1_TRANS_FP32_LEAKY_RELU \
"vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
"vcge.f32 q6, q1, q15 @ q0 > 0 \n" \
"vcge.f32 q7, q2, q15 @ q0 > 0 \n" \
"vcge.f32 q8, q3, q15 @ q0 > 0 \n" \
"vmul.f32 q9, q0, %q[scale] \n" \
"vmul.f32 q10, q1, %q[scale] \n" \
"vmul.f32 q11, q2, %q[scale] \n" \
"vmul.f32 q12, q3, %q[scale] \n" \
"vbif q0, q9, q5 @ choose \n" \
"vbif q1, q10, q6 @ choose \n" \
"vbif q2, q11, q7 @ choose \n" \
"vbif q3, q12, q8 @ choose \n"
#define NCHWC1_TRANS_FP32_STORE \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \
"vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \
"subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
\
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \
"vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \
"vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result, \n" \
\
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \
\
"bne 1b @ jump to main loop\n"
#endif
// clang-format on
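// The LEAKY_RELU blocks above are a branch-free select: build a mask of the
// lanes where x >= 0, compute x * scale, then bit-select between the two.
// An equivalent NEON-intrinsics sketch for a single vector (illustrative;
// the kernels themselves stay in assembly):
static inline float32x4_t leaky_relu_f32x4_ref(float32x4_t v,
                                               float32x4_t vscale) {
  uint32x4_t ge_zero = vcgeq_f32(v, vdupq_n_f32(0.f));  // mask: v >= 0
  float32x4_t scaled = vmulq_f32(v, vscale);            // v * scale
  return vbslq_f32(ge_zero, v, scaled);  // keep v where the mask is set
}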
inline void act_switch_c1_fp32(const float* din_ptr,
float* doutc0_ptr,
int cnt_loop,
const operators::ActivationParam* act_param) {
if (act_param != nullptr && act_param->has_active) {
float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
:
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v20");
#else
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
: [six] "w"(six)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v20");
#else
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU
NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
: [six] "w"(six)
: "q0", "q1", "q2", "q3", "q15");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU
NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
: [scale] "w"(scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v20");
#else
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU
NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(scale)
: "q0",
"q1",
"q2",
"q3",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q15");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
:
: "v0", "v1", "v2", "v3", "v20");
#else
asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
}
}
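// Scalar reference for the fused epilogues implemented in assembly above
// (illustrative; follows the ActivationParam conventions used here): relu
// clamps at zero, relu6 additionally clips at `six`, leaky relu scales
// negative inputs by `scale`.
static inline float act_ref_fp32(float x,
                                 lite_api::ActivationType type,
                                 float six,
                                 float scale) {
  switch (type) {
    case lite_api::ActivationType::kRelu:
      return LITEMAX(x, 0.f);
    case lite_api::ActivationType::kRelu6:
      return LITEMIN(LITEMAX(x, 0.f), six);
    case lite_api::ActivationType::kLeakyRelu:
      return x >= 0.f ? x : x * scale;
    default:
      return x;
  }
}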
/* write result to outputs
* input din: [n, c, h, w], output dout: [n, c, h, w]
*/
......@@ -605,13 +838,14 @@ inline bool write_to_output_c1_fp32(const float* din,
int height,
int width,
bool flag_relu,
float* trash_ptr) {
float* trash_ptr,
operators::ActivationParam* act_param) {
if (cs > channel) {
return true;
}
const int c1 = 1;
const int w4 = 4;
const int w4 = 16;
int size_c_out = width * height;
......@@ -623,98 +857,53 @@ inline bool write_to_output_c1_fp32(const float* din,
int w_round = we - ws;
int cnt = (width - ws) / w4;
int remain = (width - ws) % w4;
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
const float* din_hei_ptr = ptr_din + i * w_round * c1;
if (cnt > 0) {
int cnt_loop = cnt;
if (flag_relu) {
#ifdef __aarch64__
asm volatile(
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"fmax v1.4s, v0.4s, v20.4s \n" /*relu*/
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0", "v1", "v20");
#else
asm volatile(
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, "
"c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n"
"vmov.u32 q15, #0 @ dump zero\n"
"1: @ main loop\n"
"vmax.f32 q1, q0, q15 @ relu\n"
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n"
"vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q15");
#endif
} else {
#ifdef __aarch64__
asm volatile(
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"1: \n" /* main loop*/
"str q0, [%[doutc0r0]], #16 \n" /* store c2r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2,
c0r3 */
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(doutc0_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0");
#else
asm volatile(
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, "
"c0r1, c0r2, c0r3\n"
"1: @ main loop\n"
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add "
"pointer\n"
"subs %[cnt], %[cnt], #1 @ loop count - 1\n"
"vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n"
"bne 1b @ jump to main loop\n"
: [doutc0r0] "+r"(doutc0_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0");
#endif
}
act_switch_c1_fp32(din_hei_ptr, doutc0_ptr, cnt_loop, act_param);
}
if (we > width) {
if (remain > 0) {
int offset = i * w_round * c1 + c1 * w4 * cnt;
din_hei_ptr = ptr_din + offset;
int j = we - w4;
if (flag_relu) {
for (; j < width; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
din_hei_ptr++;
doutc0_ptr += w4 * cnt;
int j = w4 * cnt;
if (act_param != nullptr && act_param->has_active) {
float six = act_param->Relu_clipped_coef;
float scale = act_param->Leaky_relu_alpha;
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (; j < width; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
din_hei_ptr++;
}
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
for (; j < width; ++j) {
float tmp = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc0_ptr++) = LITEMIN(tmp, six);
din_hei_ptr++;
}
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
for (; j < width; ++j) {
if (din_hei_ptr[0] >= 0) {
*(doutc0_ptr++) = din_hei_ptr[0];
} else {
*(doutc0_ptr++) = din_hei_ptr[0] * scale;
}
din_hei_ptr++;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
for (; j < width; ++j) {
......@@ -725,6 +914,7 @@ inline bool write_to_output_c1_fp32(const float* din,
}
return true;
}
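// Tiling sketch for the c1 writer above (hypothetical sizes): with ws = 0 and
// width = 37, w4 = 16 gives cnt = 2 vector iterations covering 32 outputs,
// remain = 5, and the scalar tail handles j = 32 .. 36.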
// clang-format off
#ifdef __aarch64__
#define NCHWC2_TRANS_FP32_COMPUTE \
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/ \
......@@ -740,6 +930,18 @@ inline bool write_to_output_c1_fp32(const float* din,
"fmax v2.4s, v4.4s, v20.4s \n" /*relu*/ \
"fmax v3.4s, v5.4s, v20.4s \n" /*relu*/
#define NCHWC2_TRANS_FP32_RELU6 \
"fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \
"fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC2_TRANS_FP32_LEAKY_RELU \
"cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \
"bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \
"bif v3.16b, v5.16b, v7.16b \n" /* choose*/
#define NCHWC2_TRANS_FP32_STORE \
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \
\
......@@ -749,8 +951,7 @@ inline bool write_to_output_c1_fp32(const float* din,
"bne 1b \n" /* jump to main loop*/
#else
#define NCHWC2_TRANS_FP32_COMPUTE \
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " \
"c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" \
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, c1r0 \n" \
"vmov.u32 q15, #0 @ dump zero\n" \
"1: @ main loop\n" \
"vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " \
......@@ -764,11 +965,21 @@ inline bool write_to_output_c1_fp32(const float* din,
"vmax.f32 q0, q0, q15 @ relu\n" \
"vmax.f32 q1, q1, q15 @ relu\n"
#define NCHWC2_TRANS_FP32_RELU6 \
"vmin.f32 q0, q0, %q[six] @ relu6 \n" \
"vmin.f32 q1, q1, %q[six] @ relu6 \n"
#define NCHWC2_TRANS_FP32_LEAKY_RELU \
"vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
"vcge.f32 q6, q1, q15 @ q0 > 0 \n" \
"vmul.f32 q9, q0, %q[scale] \n" \
"vmul.f32 q10, q1, %q[scale] \n" \
"vbif q0, q9, q5 @ choose \n" \
"vbif q1, q10, q6 @ choose \n"
#define NCHWC2_TRANS_FP32_STORE \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \
"pointer\n" \
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " \
"pointer\n" \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" \
\
"subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
\
......@@ -776,6 +987,151 @@ inline bool write_to_output_c1_fp32(const float* din,
\
"bne 1b @ jump to main loop\n"
#endif
// clang-format on
inline void act_switch_c2_fp32(const float* din_ptr,
float* doutc0_ptr,
float* doutc1_ptr,
int cnt_loop,
const operators::ActivationParam* act_param) {
if (act_param != nullptr && act_param->has_active) {
float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
:
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v20");
#else
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
: [six] "w"(six)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v20");
#else
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
: [six] "w"(six)
: "q0", "q1", "q2", "q3", "q15");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU
NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
: [scale] "w"(scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v20");
#else
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU
NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(scale)
: "q0",
"q1",
"q2",
"q3",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q15");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v20");
#else
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
}
}
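// Scalar reference for the c2 de-interleave performed above (illustrative):
// din holds cnt * 8 floats as {c0, c1} pairs in width order, matching the
// remainder loop of write_to_output_c2_fp32 below.
static inline void c2_deinterleave_ref(const float* din,
                                       float* doutc0,
                                       float* doutc1,
                                       int cnt) {
  for (int i = 0; i < cnt * 4; ++i) {
    doutc0[i] = din[2 * i + 0];
    doutc1[i] = din[2 * i + 1];
  }
}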
/* write result to outputs
* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
*/
......@@ -791,11 +1147,11 @@ inline bool write_to_output_c2_fp32(const float* din,
int height,
int width,
bool flag_relu,
float* trash_ptr) {
float* trash_ptr,
operators::ActivationParam* act_param) {
if (cs > channel) {
return true;
}
const int c2 = 2;
const int w4 = 4;
......@@ -828,55 +1184,56 @@ inline bool write_to_output_c2_fp32(const float* din,
const float* din_hei_ptr = ptr_din + i * w_round * c2;
if (cnt > 0) {
int cnt_loop = cnt;
if (flag_relu) {
#ifdef __aarch64__
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v20");
#else
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU
NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
} else {
#ifdef __aarch64__
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5");
#else
asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
}
act_switch_c2_fp32(
din_hei_ptr, doutc0_ptr, doutc1_ptr, cnt_loop, act_param);
}
if (we > width) {
int offset = i * w_round * c2 + c2 * w4 * cnt;
din_hei_ptr = ptr_din + offset;
doutc0_ptr += w4 * cnt;
doutc1_ptr += w4 * cnt;
int j = we - w4;
if (flag_relu) {
for (; j < width; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
din_hei_ptr += 2;
if (act_param != nullptr && act_param->has_active) {
float six = act_param->Relu_clipped_coef;
float scale = act_param->Leaky_relu_alpha;
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (; j < width; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
din_hei_ptr += 2;
}
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
for (; j < width; ++j) {
float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
*(doutc0_ptr++) = LITEMIN(tmp1, six);
*(doutc1_ptr++) = LITEMIN(tmp2, six);
din_hei_ptr += 2;
}
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
for (; j < width; ++j) {
if (din_hei_ptr[0] >= 0) {
*(doutc0_ptr++) = din_hei_ptr[0];
} else {
*(doutc0_ptr++) = din_hei_ptr[0] * scale;
}
if (din_hei_ptr[1] >= 0) {
*(doutc1_ptr++) = din_hei_ptr[1];
} else {
*(doutc1_ptr++) = din_hei_ptr[1] * scale;
}
din_hei_ptr += 2;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
for (; j < width; ++j) {
......@@ -888,7 +1245,7 @@ inline bool write_to_output_c2_fp32(const float* din,
}
return true;
}
// clang-format off
#ifdef __aarch64__
#define NCHWC4_TRANS_FP32_COMPUTE \
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
......@@ -912,6 +1269,26 @@ inline bool write_to_output_c2_fp32(const float* din,
"fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \
"fmax v19.4s, v19.4s, v20.4s \n" /*relu*/
#define NCHWC4_TRANS_FP32_RELU6 \
"fmin v16.4s, v16.4s, %[six].4s \n" /* relu6 */ \
"fmin v17.4s, v17.4s, %[six].4s \n" /* relu6 */ \
"fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */ \
"fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
#define NCHWC4_TRANS_FP32_LEAKY_RELU \
"cmhs v8.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v9.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v10.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v11.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
"fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
"fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \
"bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \
"bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \
"bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \
"bif v19.16b, v7.16b, v11.16b \n" /* choose*/
#define NCHWC4_TRANS_FP32_STORE \
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \
......@@ -940,6 +1317,26 @@ inline bool write_to_output_c2_fp32(const float* din,
"vmax.f32 q2, q2, q15 @ relu\n" \
"vmax.f32 q3, q3, q15 @ relu\n"
#define NCHWC4_TRANS_FP32_RELU6 \
"vmin.f32 q0, q0, %q[six] @ relu6 \n" \
"vmin.f32 q1, q1, %q[six] @ relu6 \n" \
"vmin.f32 q2, q2, %q[six] @ relu6 \n" \
"vmin.f32 q3, q3, %q[six] @ relu6 \n"
#define NCHWC4_TRANS_FP32_LEAKY_RELU \
"vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
"vcge.f32 q6, q1, q15 @ q0 > 0 \n" \
"vcge.f32 q7, q2, q15 @ q0 > 0 \n" \
"vcge.f32 q8, q3, q15 @ q0 > 0 \n" \
"vmul.f32 q9, q0, %q[scale] \n" \
"vmul.f32 q10, q1, %q[scale] \n" \
"vmul.f32 q11, q2, %q[scale] \n" \
"vmul.f32 q12, q3, %q[scale] \n" \
"vbif q0, q9, q5 @ choose \n" \
"vbif q1, q10, q6 @ choose \n" \
"vbif q2, q11, q7 @ choose \n" \
"vbif q3, q12, q8 @ choose \n"
#define NCHWC4_TRANS_FP32_STORE \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
"vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" \
......@@ -953,68 +1350,19 @@ inline bool write_to_output_c2_fp32(const float* din,
\
"bne 1b @ jump to main loop\n"
#endif
/* write result to outputs
* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
*/
inline bool write_to_output_c4_fp32(const float* din,
float* dout,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
float* trash_ptr) {
const int c4 = 4;
const int w4 = 4;
const int w_round = we - ws;
const int ch_n = ce - cs;
if (ch_n != 4) {
LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is "
"more than zero";
return false;
}
int size_c_out = width * height;
float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
float* doutc1r0 = doutc0r0 + size_c_out;
float* doutc2r0 = doutc1r0 + size_c_out;
float* doutc3r0 = doutc2r0 + size_c_out;
const float* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int valid_we = we > width ? width : we;
int cnt = (valid_we - ws) / w4;
int remain = valid_we - ws - cnt * w4;
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
float* doutc1_ptr = doutc1r0 + size_w;
float* doutc2_ptr = doutc2r0 + size_w;
float* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
const float* din_hei_ptr = ptr_din + i * w_round * ch_n;
if (cnt > 0) {
int cnt_loop = cnt;
if (flag_relu) {
// clang-format on
inline void act_switch_c4_fp32(const float* din_ptr,
float* doutc0_ptr,
float* doutc1_ptr,
float* doutc2_ptr,
float* doutc3_ptr,
int cnt_loop,
const operators::ActivationParam* act_param) {
if (act_param != nullptr && act_param->has_active) {
float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
NCHWC4_TRANS_FP32_STORE
......@@ -1023,7 +1371,7 @@ inline bool write_to_output_c4_fp32(const float* din,
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
[ptr_din] "+r"(din_ptr)
:
: "v0",
"v1",
......@@ -1052,57 +1400,290 @@ inline bool write_to_output_c4_fp32(const float* din,
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
} else {
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
[ptr_din] "+r"(din_ptr)
: [six] "w"(six)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v16",
"v17",
"v18",
"v19");
"v19",
"v20");
#else
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU
NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_hei_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3");
: [six] "w"(six)
: "q0", "q1", "q2", "q3", "q15");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU
NCHWC4_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
: [scale] "w"(scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU
NCHWC4_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
: [scale] "w"(scale)
: "q0",
"q1",
"q2",
"q3",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q15");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
:
: "v0",
"v1",
"v2",
"v3",
"v8",
"v9",
"v10",
"v11",
"v16",
"v17",
"v18",
"v19");
#else
asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q15");
#endif
}
}
/* write result to outputs
* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w]
*/
inline bool write_to_output_c4_fp32(const float* din,
float* dout,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
float* trash_ptr,
operators::ActivationParam* act_param) {
const int c4 = 4;
const int w4 = 4;
const int w_round = we - ws;
const int ch_n = ce - cs;
if (ch_n != 4) {
LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is "
"more than zero";
return false;
}
int size_c_out = width * height;
float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
float* doutc1r0 = doutc0r0 + size_c_out;
float* doutc2r0 = doutc1r0 + size_c_out;
float* doutc3r0 = doutc2r0 + size_c_out;
const float* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int valid_we = we > width ? width : we;
int cnt = (valid_we - ws) / w4;
int remain = valid_we - ws - cnt * w4;
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
float* doutc1_ptr = doutc1r0 + size_w;
float* doutc2_ptr = doutc2r0 + size_w;
float* doutc3_ptr = doutc3r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 3:
doutc1_ptr = trash_ptr;
case 2:
doutc2_ptr = trash_ptr;
case 1:
doutc3_ptr = trash_ptr;
default:
break;
}
}
const float* din_hei_ptr = ptr_din + i * w_round * ch_n;
if (cnt > 0) {
int cnt_loop = cnt;
act_switch_c4_fp32(din_hei_ptr,
doutc0_ptr,
doutc1_ptr,
doutc2_ptr,
doutc3_ptr,
cnt_loop,
act_param);
}
if (remain > 0) {
int offset = i * w_round * c4 + c4 * w4 * cnt;
din_hei_ptr = ptr_din + offset;
doutc0_ptr += w4 * cnt;
doutc1_ptr += w4 * cnt;
doutc2_ptr += w4 * cnt;
doutc3_ptr += w4 * cnt;
int j = 0;
if (flag_relu) {
for (; j < remain; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
din_hei_ptr += w4;
if (act_param != nullptr && act_param->has_active) {
float six = act_param->Relu_clipped_coef;
float scale = act_param->Leaky_relu_alpha;
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (; j < remain; ++j) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
din_hei_ptr += 4;
}
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
for (; j < remain; ++j) {
float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
float tmp3 = LITEMAX(din_hei_ptr[2], 0.f);
float tmp4 = LITEMAX(din_hei_ptr[3], 0.f);
*(doutc0_ptr++) = LITEMIN(tmp1, six);
*(doutc1_ptr++) = LITEMIN(tmp2, six);
*(doutc2_ptr++) = LITEMIN(tmp3, six);
*(doutc3_ptr++) = LITEMIN(tmp4, six);
din_hei_ptr += 4;
}
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
for (; j < remain; ++j) {
if (din_hei_ptr[0] >= 0) {
*(doutc0_ptr++) = din_hei_ptr[0];
} else {
*(doutc0_ptr++) = din_hei_ptr[0] * scale;
}
if (din_hei_ptr[1] >= 0) {
*(doutc1_ptr++) = din_hei_ptr[1];
} else {
*(doutc1_ptr++) = din_hei_ptr[1] * scale;
}
if (din_hei_ptr[2] >= 0) {
*(doutc2_ptr++) = din_hei_ptr[2];
} else {
*(doutc2_ptr++) = din_hei_ptr[2] * scale;
}
if (din_hei_ptr[3] >= 0) {
*(doutc3_ptr++) = din_hei_ptr[3];
} else {
*(doutc3_ptr++) = din_hei_ptr[3] * scale;
}
din_hei_ptr += 4;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
for (; j < remain; ++j) {
......@@ -1110,14 +1691,14 @@ inline bool write_to_output_c4_fp32(const float* din,
*(doutc1_ptr++) = din_hei_ptr[1];
*(doutc2_ptr++) = din_hei_ptr[2];
*(doutc3_ptr++) = din_hei_ptr[3];
din_hei_ptr += w4;
din_hei_ptr += 4;
}
}
}
}
return true;
}
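// Minimal caller-side sketch for the writer above; pre_out, dout and the
// tile bounds are hypothetical placeholders, and flag_relu is effectively
// superseded by act_param on this path:
//
//   operators::ActivationParam act_param;
//   act_param.has_active = true;
//   act_param.active_type = lite_api::ActivationType::kRelu6;
//   act_param.Relu_clipped_coef = 6.f;  // the "six" used by the relu6 paths
//   write_to_output_c4_fp32(pre_out, dout, cs, cs + 4, hs, he, ws, we,
//                           channel, height, width, false, trash_ptr,
//                           &act_param);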
// clang-format off
#ifdef __aarch64__
#define NCHWC8_TRANS_FP32_COMPUTE \
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
......@@ -1161,6 +1742,48 @@ inline bool write_to_output_c4_fp32(const float* din,
"fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, v20.4s \n" /*relu*/
#define NCHWC8_TRANS_FP32_RELU6 \
"fmin v16.4s, v16.4s, %[six].4s \n" /*relu6*/ \
"fmin v17.4s, v17.4s, %[six].4s \n" /*relu6*/ \
"fmin v18.4s, v18.4s, %[six].4s \n" /*relu6*/ \
"fmin v19.4s, v19.4s, %[six].4s \n" /*relu6*/ \
\
"fmin v8.4s, v8.4s, %[six].4s \n" /*relu6*/ \
"fmin v9.4s, v9.4s, %[six].4s \n" /*relu6*/ \
"fmin v12.4s, v12.4s, %[six].4s \n" /*relu6*/ \
"fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/
#define NCHWC8_TRANS_FP32_LEAKY_RELU \
"cmhs v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
\
"cmhs v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \
\
"fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \
"fmul v27.4s, v18.4s, %[scale].4s \n" /* mul */ \
"fmul v28.4s, v19.4s, %[scale].4s \n" /* mul */ \
\
"fmul v29.4s, v8.4s, %[scale].4s \n" /* mul */ \
"fmul v30.4s, v9.4s, %[scale].4s \n" /* mul */ \
"fmul v31.4s, v12.4s, %[scale].4s \n" /* mul */ \
\
"bif v16.16b, v25.16b, v10.16b \n" /* choose*/ \
"bif v17.16b, v26.16b, v11.16b \n" /* choose*/ \
"bif v18.16b, v27.16b, v14.16b \n" /* choose*/ \
"bif v19.16b, v28.16b, v15.16b \n" /* choose*/ \
"fmul v25.4s, v13.4s, %[scale].4s \n" /* mul */ \
\
"bif v8.16b, v29.16b, v21.16b \n" /* choose*/ \
"bif v9.16b, v30.16b, v22.16b \n" /* choose*/ \
"bif v12.16b, v31.16b, v23.16b \n" /* choose*/ \
"bif v13.16b, v25.16b, v24.16b \n" /* choose*/
#define NCHWC8_TRANS_FP32_STORE \
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \
......@@ -1174,6 +1797,7 @@ inline bool write_to_output_c4_fp32(const float* din,
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ \
\
"bne 1b \n" /* jump to main loop*/
#else
#define NCHWC8_TRANS_FP32_COMPUTE \
"vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" \
......@@ -1203,6 +1827,48 @@ inline bool write_to_output_c4_fp32(const float* din,
"vmax.f32 q6, q6, q15 @ relu\n" \
"vmax.f32 q7, q7, q15 @ relu\n"
#define NCHWC8_TRANS_FP32_RELU6 \
"vmin.f32 q0, q0, %q[six] @ relu6\n" \
"vmin.f32 q1, q1, %q[six] @ relu6\n" \
"vmin.f32 q2, q2, %q[six] @ relu6\n" \
"vmin.f32 q3, q3, %q[six] @ relu6\n" \
\
"vmin.f32 q4, q4, %q[six] @ relu6\n" \
"vmin.f32 q5, q5, %q[six] @ relu6\n" \
"vmin.f32 q6, q6, %q[six] @ relu6\n" \
"vmin.f32 q7, q7, %q[six] @ relu6\n"
#define NCHWC8_TRANS_FP32_LEAKY_RELU \
"vcge.f32 q9, q0, q15 @ q0 > 0 \n" \
"vcge.f32 q10, q1, q15 @ q0 > 0 \n" \
"vcge.f32 q11, q2, q15 @ q0 > 0 \n" \
"vcge.f32 q12, q3, q15 @ q0 > 0 \n" \
"vmul.f32 q13, q0, %q[scale] \n" \
"vmul.f32 q14, q1, %q[scale] \n" \
"vmul.f32 q15, q2, %q[scale] \n" \
\
"vbif q0, q13, q9 @ choose \n" \
"vmul.f32 q9, q3, %q[scale] \n" \
\
"vbif q1, q14, q10 @ choose \n" \
"vbif q2, q15, q11 @ choose \n" \
"vbif q3, q9, q12 @ choose \n" \
\
"vcge.f32 q9, q4, q15 @ q0 > 0 \n" \
"vcge.f32 q10, q5, q15 @ q0 > 0 \n" \
"vcge.f32 q11, q6, q15 @ q0 > 0 \n" \
"vcge.f32 q12, q7, q15 @ q0 > 0 \n" \
"vmul.f32 q13, q4, %q[scale] \n" \
"vmul.f32 q14, q5, %q[scale] \n" \
"vmul.f32 q15, q6, %q[scale] \n" \
\
"vbif q4, q13, q9 @ choose \n" \
"vmul.f32 q9, q7, %q[scale] \n" \
\
"vbif q5, q14, q10 @ choose \n" \
"vbif q6, q15, q11 @ choose \n" \
"vbif q7, q9, q12 @ choose \n"
#define NCHWC8_TRANS_FP32_STORE \
"subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \
......@@ -1232,84 +1898,23 @@ inline bool write_to_output_c4_fp32(const float* din,
"bne 1b @ jump to main loop\n"
#endif
/* write result to outputs
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
inline bool write_to_output_c8_fp32(const float* din,
float* dout,
int ch_n,
int hei_n,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
float* trash_ptr) {
if (ch_n != 8 || hei_n <= 0) {
LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
return false;
}
int size_c_out = width * height;
float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
float* doutc1r0 = doutc0r0 + size_c_out;
float* doutc2r0 = doutc1r0 + size_c_out;
float* doutc3r0 = doutc2r0 + size_c_out;
float* doutc4r0 = doutc3r0 + size_c_out;
float* doutc5r0 = doutc4r0 + size_c_out;
float* doutc6r0 = doutc5r0 + size_c_out;
float* doutc7r0 = doutc6r0 + size_c_out;
const float* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int valid_w = we - ws;
int cnt = valid_w / 4;
if (we > width) {
cnt--;
}
if (flag_relu) {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
float* doutc1_ptr = doutc1r0 + size_w;
float* doutc2_ptr = doutc2r0 + size_w;
float* doutc3_ptr = doutc3r0 + size_w;
float* doutc4_ptr = doutc4r0 + size_w;
float* doutc5_ptr = doutc5r0 + size_w;
float* doutc6_ptr = doutc6r0 + size_w;
float* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const float* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
// clang-format on
inline void act_switch_c8_fp32(const float* din_ptr,
float* doutc0_ptr,
float* doutc1_ptr,
float* doutc2_ptr,
float* doutc3_ptr,
float* doutc4_ptr,
float* doutc5_ptr,
float* doutc6_ptr,
float* doutc7_ptr,
int cnt_loop,
const operators::ActivationParam* act_param) {
if (act_param != nullptr && act_param->has_active) {
float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha);
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
NCHWC8_TRANS_FP32_STORE
......@@ -1322,9 +1927,10 @@ inline bool write_to_output_c8_fp32(const float* din,
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
[ptr_din] "+r"(din_ptr)
:
: "v1",
: "v0",
"v1",
"v2",
"v3",
"v4",
......@@ -1338,7 +1944,6 @@ inline bool write_to_output_c8_fp32(const float* din,
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
......@@ -1355,66 +1960,17 @@ inline bool write_to_output_c8_fp32(const float* din,
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4", "q15");
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
#endif
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
for (; i < width; ++i) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
*(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f);
*(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f);
*(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f);
*(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f);
din_hei_ptr += 8;
}
}
}
} else {
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
float* doutc1_ptr = doutc1r0 + size_w;
float* doutc2_ptr = doutc2r0 + size_w;
float* doutc3_ptr = doutc3r0 + size_w;
float* doutc4_ptr = doutc4r0 + size_w;
float* doutc5_ptr = doutc5r0 + size_w;
float* doutc6_ptr = doutc6r0 + size_w;
float* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
}
ptr_din = din + i * valid_w * ch_n;
const float* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU6
NCHWC8_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
......@@ -1424,8 +1980,8 @@ inline bool write_to_output_c8_fp32(const float* din,
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_hei_ptr)
:
[ptr_din] "+r"(din_ptr)
: [six] "w"(six)
: "v0",
"v1",
"v2",
......@@ -1441,14 +1997,29 @@ inline bool write_to_output_c8_fp32(const float* din,
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU
NCHWC8_TRANS_FP32_RELU6 NCHWC8_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
: [six] "w"(six)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
NCHWC8_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
......@@ -1457,16 +2028,323 @@ inline bool write_to_output_c8_fp32(const float* din,
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_hei_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
: [scale] "w"(scale)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"v26",
"v27",
"v28",
"v29",
"v30",
"v31");
#else
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU
NCHWC8_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4");
: [scale] "w"(scale)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[cnt] "+r"(cnt_loop),
[ptr_din] "+r"(din_ptr)
:
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE
: [doutc0r0] "+r"(doutc0_ptr),
[doutc1r0] "+r"(doutc1_ptr),
[doutc2r0] "+r"(doutc2_ptr),
[doutc3r0] "+r"(doutc3_ptr),
[doutc4r0] "+r"(doutc4_ptr),
[doutc5r0] "+r"(doutc5_ptr),
[doutc6r0] "+r"(doutc6_ptr),
[doutc7r0] "+r"(doutc7_ptr),
[ptr_din] "+r"(din_ptr),
[cnt] "+r"(cnt_loop)
:
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
#endif
}
}
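// Scalar reference for the c8 de-interleave (illustrative): din holds
// cnt * 32 floats as 8-channel groups in width order, matching the remainder
// loops of write_to_output_c8_fp32 below.
static inline void c8_deinterleave_ref(const float* din,
                                       float* dout[8],
                                       int cnt) {
  for (int i = 0; i < cnt * 4; ++i) {
    for (int c = 0; c < 8; ++c) {
      dout[c][i] = din[8 * i + c];
    }
  }
}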
/* write result to outputs
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
inline bool write_to_output_c8_fp32(const float* din,
float* dout,
int ch_n,
int hei_n,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int height,
int width,
bool flag_relu,
float* trash_ptr,
operators::ActivationParam* act_param) {
if (ch_n != 8 || hei_n <= 0) {
LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero";
return false;
}
int size_c_out = width * height;
float* doutc0r0 = dout + cs * size_c_out + hs * width + ws;
float* doutc1r0 = doutc0r0 + size_c_out;
float* doutc2r0 = doutc1r0 + size_c_out;
float* doutc3r0 = doutc2r0 + size_c_out;
float* doutc4r0 = doutc3r0 + size_c_out;
float* doutc5r0 = doutc4r0 + size_c_out;
float* doutc6r0 = doutc5r0 + size_c_out;
float* doutc7r0 = doutc6r0 + size_c_out;
const float* ptr_din = din;
int size_h = (he > height ? height : he) - hs; // size_h == hei_n
int valid_w = we - ws;
int w4 = 4;
int cnt = valid_w / 4;
if (we > width) {
cnt--;
}
for (int i = 0; i < size_h; i++) {
int size_w = i * width;
float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width;
float* doutc1_ptr = doutc1r0 + size_w;
float* doutc2_ptr = doutc2r0 + size_w;
float* doutc3_ptr = doutc3r0 + size_w;
float* doutc4_ptr = doutc4r0 + size_w;
float* doutc5_ptr = doutc5r0 + size_w;
float* doutc6_ptr = doutc6r0 + size_w;
float* doutc7_ptr = doutc7r0 + size_w;
if (ce > channel) {
switch (ce - channel) {
case 7:
doutc1_ptr = trash_ptr;
case 6:
doutc2_ptr = trash_ptr;
case 5:
doutc3_ptr = trash_ptr;
case 4:
doutc4_ptr = trash_ptr;
case 3:
doutc5_ptr = trash_ptr;
case 2:
doutc6_ptr = trash_ptr;
case 1:
doutc7_ptr = trash_ptr;
default:
break;
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
int i = we - 4;
}
ptr_din = din + i * valid_w * ch_n;
const float* din_hei_ptr = ptr_din;
if (cnt > 0) {
int cnt_loop = cnt;
act_switch_c8_fp32(din_hei_ptr,
doutc0_ptr,
doutc1_ptr,
doutc2_ptr,
doutc3_ptr,
doutc4_ptr,
doutc5_ptr,
doutc6_ptr,
doutc7_ptr,
cnt_loop,
act_param);
}
if (we > width) {
int offset = 32 * (valid_w / 4 - 1);
din_hei_ptr = ptr_din + offset;
doutc0_ptr += w4 * cnt;
doutc1_ptr += w4 * cnt;
doutc2_ptr += w4 * cnt;
doutc3_ptr += w4 * cnt;
doutc4_ptr += w4 * cnt;
doutc5_ptr += w4 * cnt;
doutc6_ptr += w4 * cnt;
doutc7_ptr += w4 * cnt;
int i = we - 4;
if (act_param != nullptr && act_param->has_active) {
float six = act_param->Relu_clipped_coef;
float scale = act_param->Leaky_relu_alpha;
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (; i < width; ++i) {
*(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
*(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
*(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
*(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f);
*(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f);
*(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f);
*(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f);
*(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f);
din_hei_ptr += 8;
}
break;
case lite_api::ActivationType::kRelu6:
/* dout = min(max(din, 0), six) */
for (; i < width; ++i) {
float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
float tmp3 = LITEMAX(din_hei_ptr[2], 0.f);
float tmp4 = LITEMAX(din_hei_ptr[3], 0.f);
float tmp5 = LITEMAX(din_hei_ptr[4], 0.f);
float tmp6 = LITEMAX(din_hei_ptr[5], 0.f);
float tmp7 = LITEMAX(din_hei_ptr[6], 0.f);
float tmp8 = LITEMAX(din_hei_ptr[7], 0.f);
*(doutc0_ptr++) = LITEMIN(tmp1, six);
*(doutc1_ptr++) = LITEMIN(tmp2, six);
*(doutc2_ptr++) = LITEMIN(tmp3, six);
*(doutc3_ptr++) = LITEMIN(tmp4, six);
*(doutc4_ptr++) = LITEMIN(tmp5, six);
*(doutc5_ptr++) = LITEMIN(tmp6, six);
*(doutc6_ptr++) = LITEMIN(tmp7, six);
*(doutc7_ptr++) = LITEMIN(tmp8, six);
din_hei_ptr += 8;
}
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
for (; i < width; ++i) {
if (din_hei_ptr[0] >= 0) {
*(doutc0_ptr++) = din_hei_ptr[0];
} else {
*(doutc0_ptr++) = din_hei_ptr[0] * scale;
}
if (din_hei_ptr[1] >= 0) {
*(doutc1_ptr++) = din_hei_ptr[1];
} else {
*(doutc1_ptr++) = din_hei_ptr[1] * scale;
}
if (din_hei_ptr[2] >= 0) {
*(doutc2_ptr++) = din_hei_ptr[2];
} else {
*(doutc2_ptr++) = din_hei_ptr[2] * scale;
}
if (din_hei_ptr[3] >= 0) {
*(doutc3_ptr++) = din_hei_ptr[3];
} else {
*(doutc3_ptr++) = din_hei_ptr[3] * scale;
}
if (din_hei_ptr[4] >= 0) {
*(doutc4_ptr++) = din_hei_ptr[4];
} else {
*(doutc4_ptr++) = din_hei_ptr[4] * scale;
}
if (din_hei_ptr[5] >= 0) {
*(doutc5_ptr++) = din_hei_ptr[5];
} else {
*(doutc5_ptr++) = din_hei_ptr[5] * scale;
}
if (din_hei_ptr[6] >= 0) {
*(doutc6_ptr++) = din_hei_ptr[6];
} else {
*(doutc6_ptr++) = din_hei_ptr[6] * scale;
}
if (din_hei_ptr[7] >= 0) {
*(doutc7_ptr++) = din_hei_ptr[7];
} else {
*(doutc7_ptr++) = din_hei_ptr[7] * scale;
}
din_hei_ptr += 8;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
} else {
for (; i < width; ++i) {
*(doutc0_ptr++) = din_hei_ptr[0];
*(doutc1_ptr++) = din_hei_ptr[1];
......
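The elided default branch simply copies the remaining columns through unchanged. As a compact reference for what the three fused branches above compute per element, a minimal standalone sketch (assumptions: std::max/std::min stand in for LITEMAX/LITEMIN, and act_fuse_scalar is a hypothetical helper name, not part of this patch):

#include <algorithm>

// Hypothetical scalar reference for the fused activations; `six` corresponds
// to act_param->Relu_clipped_coef and `scale` to act_param->Leaky_relu_alpha.
enum class ActType { kRelu, kRelu6, kLeakyRelu };

inline float act_fuse_scalar(float v, ActType type, float six, float scale) {
  switch (type) {
    case ActType::kRelu:
      return std::max(v, 0.f);                 // relu: max(v, 0)
    case ActType::kRelu6:
      return std::min(std::max(v, 0.f), six);  // relu6: clamp to [0, six]
    case ActType::kLeakyRelu:
      return v >= 0.f ? v : v * scale;         // leaky: scale negatives only
  }
  return v;
}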
......@@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_3x3s2_depthwise_fp32(const float* i_data,
......@@ -67,6 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2_fp32(const float* din,
......
......@@ -579,6 +579,7 @@ void conv_depthwise_3x3_fp32(const void* din,
ARMContext* ctx,
const float* scale) {
auto paddings = *param.paddings;
auto act_param = param.activation_param;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
int stride = param.strides[1];
......@@ -603,6 +604,7 @@ void conv_depthwise_3x3_fp32(const void* din,
pad,
flag_bias,
flag_relu,
act_param,
ctx);
} else {
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
......@@ -617,6 +619,7 @@ void conv_depthwise_3x3_fp32(const void* din,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
}
......
......@@ -67,7 +67,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
no_dilation && pads_all_equal) {
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
......
......@@ -52,6 +52,34 @@ inline int ConvOutputSize(int input_size,
return output_size;
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
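A worked instance of the SAME branch above, as a standalone sketch over one spatial dimension (the concrete sizes are illustrative): input 224 with stride 2 and kernel 3 gives out_size = ceil(224 / 2) = 112 and an asymmetric pad of (0, 1).

#include <algorithm>
#include <cstdio>

int main() {
  // One spatial dim with SAME padding: input 224, stride 2, kernel 3.
  int data_dim = 224, stride = 2, ksize = 3;
  int out_size = (data_dim + stride - 1) / stride;  // ceil(224 / 2) = 112
  int pad_sum =
      std::max((out_size - 1) * stride + ksize - data_dim, 0);  // max(1, 0) = 1
  int pad_0 = pad_sum / 2;      // 0
  int pad_1 = pad_sum - pad_0;  // 1 (SAME may pad asymmetrically)
  std::printf("out=%d pad=(%d,%d)\n", out_size, pad_0, pad_1);
  return 0;
}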
bool ConvOpLite::InferShape() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......
......@@ -137,34 +137,6 @@ class ConvOpLite : public OpLite {
std::string padding_algorithm_{""};
};
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -59,6 +59,8 @@ DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in,
......@@ -118,6 +120,13 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
param.dilations = std::make_shared<std::vector<int>>(dilas);
param.fuse_relu = flag_relu;
param.groups = group;
if (flag_relu) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type =
(paddle::lite_api::ActivationType)1;  // 1: relu (2: relu6, 4: leakyRelu)
param.activation_param = act_param;
}
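This test path only wires up plain relu. To exercise the new relu6/leakyRelu fusions one would presumably extend it along these lines (a sketch, not part of this patch; the field names Relu_clipped_coef and Leaky_relu_alpha are the ones read by the output-writing code above, and the numeric values are illustrative):

// Sketch only: cover the new fused activation paths as well.
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)2;  // relu6
act_param.Relu_clipped_coef = 6.f;  // the "six" clamp bound
// or leakyRelu:
// act_param.active_type = (paddle::lite_api::ActivationType)4;
// act_param.Leaky_relu_alpha = 0.1f;  // negative-slope "scale"
param.activation_param = act_param;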
param.output = new Tensor;
param.output->set_precision(PRECISION(kFloat));
......@@ -243,6 +252,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
......@@ -255,6 +265,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
......