Commit 42bbd157 authored by yiicy, committed by GitHub

[ARM] 5x5dw and sgemv support fuse activation, test=develop (#2797)

* refactor 5x5s1 dw conv armv8, test=develop

* [ARM] refactor depthwise conv 5x5s1, and support relu6, leaky relu, test=develop

* [ARM] sgemv support fuse relu6 and leaky relu, test=develop

* [ARM] reduce some conv ut case, test=develop

* [ARM] fix 5x5dw conv pick kernel bug, test=develop

* fix code style, test=develop

* [ARM] fix sgemv fuse relu6 bug, test=develop

* [ARM] fix fp32 5x5s1 dw bug, test=develop

* [ARM] fix fp32 5x5 dw conv pick kernel bug, test=develop
Parent 7a0e3fd7
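The recurring change in this diff replaces the unsigned integer compare `cmhs` with the floating-point compare `fcmge` in every leaky-relu select sequence. Interpreted as unsigned integers, all float bit patterns compare higher-or-same against zero, so the `cmhs` mask was always all-ones and the `bif` select never routed negative inputs through the `alpha * x` path. A minimal NEON-intrinsics sketch of the corrected pattern (illustrative only, not code from this commit; the function name is invented):

    #include <arm_neon.h>

    // out[i] = x[i] >= 0 ? x[i] : alpha * x[i], branch-free on 4 lanes.
    // fcmge <-> vcgeq_f32, fmul <-> vmulq_f32, bif <-> vbslq_f32.
    static inline float32x4_t leaky_relu_f32x4(float32x4_t x,
                                               float32x4_t valpha) {
      uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // lane mask: x >= 0
      float32x4_t scaled = vmulq_f32(x, valpha);            // alpha * x
      return vbslq_f32(ge_zero, x, scaled);                 // pick x or alpha*x
    }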
......@@ -614,11 +614,11 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"blt 3f \n"
#define LEFT_RESULT_S1_LEAKY_RELU \
"cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
......@@ -639,8 +639,8 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \
"cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
"fcmge v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
\
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
......@@ -657,10 +657,10 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \
"fcmge v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \
"cmp %w[cnt], #1 \n" \
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
......@@ -802,7 +802,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#define MID_RESULT_S1_LEAKY_RELU \
"movi v21.4s, #0 \n" \
"cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \
"fcmge v18.4s, v12.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
......@@ -824,7 +824,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \
"fcmge v18.4s, v13.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
......@@ -846,7 +846,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \
"fcmge v18.4s, v14.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
......@@ -861,7 +861,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \
"fcmge v18.4s, v15.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
\
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
......@@ -980,7 +980,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#define RIGHT_RESULT_S1_LEAKY_RELU \
"movi v1.4s, #0 \n" \
"cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \
"fcmge v20.4s, v12.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
......@@ -999,7 +999,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \
"fcmge v20.4s, v13.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
......@@ -1017,7 +1017,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \
"fcmge v20.4s, v14.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
\
......@@ -1028,7 +1028,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\
"bif v14.16b, v24.16b, v18.16b \n" \
\
"cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \
"fcmge v20.4s, v15.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
......@@ -1128,18 +1128,18 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define RESULT_S_S1_LEAKY_RELU \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\
"bif v12.16b, v20.16b, v18.16b \n" \
"bif v13.16b, v21.16b, v19.16b \n" \
"st1 {v12.4s}, [%[out1]]\n" \
#define RESULT_S_S1_LEAKY_RELU \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\
"bif v12.16b, v20.16b, v18.16b \n" \
"bif v13.16b, v21.16b, v19.16b \n" \
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"prfm pldl1keep, [%[din0]]\n" \
......
......@@ -179,11 +179,11 @@ namespace math {
#define LEAKY_RELU \
"movi v0.4s, #0\n" /* for relu */ \
"ldr x0, [%[outl], #88]\n" \
"cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v1.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */ \
"fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ \
"ld1 {v9.4s}, [x0] \n" \
"cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v4.4s, v18.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v3.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ \
"fcmge v4.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ \
"ldr x0, [%[outl]] \n" \
"fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \
"fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \
......@@ -193,10 +193,10 @@ namespace math {
"bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \
"bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \
"bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \
"fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \
"fcmge v3.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \
"fcmge v4.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \
"fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \
"fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \
......
......@@ -50,7 +50,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias;
int pad_h = paddings[0];
int pad_w = paddings[1];
int pad_w = paddings[2];
const int threads = ctx->threads();
int llc_size = ctx->llc_size() / 4;
......@@ -477,7 +477,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias;
int pad_h = paddings[0];
int pad_w = paddings[1];
int pad_w = paddings[2];
const int threads = ctx->threads();
//! set 1/4 l2 cache
int llc_size = ctx->llc_size() / 4;
......
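The two hunks above change the width padding from paddings[1] to paddings[2], which implies the four-element layout {top, bottom, left, right} for param.paddings; index 1 is the bottom pad, not the left one. A short sketch of the indexing this fix assumes (an assumption inferred from the change, not stated elsewhere in the diff):

    // Assumed layout: paddings = {pad_top, pad_bottom, pad_left, pad_right}
    auto paddings = *param.paddings;  // four ints
    int pad_h = paddings[0];          // top pad drives the height offset
    int pad_w = paddings[2];          // left pad drives the width offset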
......@@ -451,44 +451,44 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"blt 1f \n"
#define LEFT_RESULT_S2_LEAKY_RELU \
"ld1 {v22.4s}, [%[scale_ptr]] \n" \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"cmp %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
#define LEFT_RESULT_S2_LEAKY_RELU \
"ld1 {v22.4s}, [%[scale_ptr]] \n" \
"fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"cmp %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define MID_RESULT_S2_RELU \
......@@ -542,30 +542,30 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"bne 2b \n"
#define MID_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
#define MID_RESULT_S2_LEAKY_RELU \
"fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
......@@ -606,25 +606,25 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define RIGHT_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
#define RIGHT_RESULT_S2_LEAKY_RELU \
"fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define COMPUTE_S_S2 \
......
......@@ -104,13 +104,13 @@ namespace math {
"fmin v22.4s, v22.4s, %[vsix].4s\n"
#define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
......
Source diff not shown because it is too large. You can view the blob instead.
......@@ -709,7 +709,6 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
"q15");
#endif
// clang-format on
int32_t* ptr_tmp = ptr_out0 - w_loop * 32;
block_inr0 = block_inr1;
block_inr1 = block_inr2;
block_inr2 = block_inr3;
......
......@@ -198,24 +198,24 @@ namespace math {
"fmin v20.4s, v20.4s, %[vsix].4s\n" \
"fmin v21.4s, v21.4s, %[vsix].4s\n" \
"fmin v22.4s, v22.4s, %[vsix].4s\n"
#define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v19.16b, v8.16b, v7.16b \n" /* choose*/
#define STORE /* save result */ \
"str q19, [%[outc0]], #16\n" \
"str q20, [%[outc1]], #16\n" \
"str q21, [%[outc2]], #16\n" \
#define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \
"fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v19.16b, v8.16b, v7.16b \n" /* choose*/
#define STORE /* save result */ \
"str q19, [%[outc0]], #16\n" \
"str q20, [%[outc1]], #16\n" \
"str q21, [%[outc2]], #16\n" \
"str q22, [%[outc3]], #16\n"
#else
......
......@@ -614,16 +614,16 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
"fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC1_TRANS_FP32_LEAKY_RELU \
"cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \
"fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \
"fcmge v4.4s, v0.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v5.4s, v1.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \
"fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \
"fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \
"fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */ \
"bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \
"bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \
"bif v2.16b, v10.16b, v6.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v7.16b \n" /* choose*/
......@@ -674,15 +674,15 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
"vbif q3, q12, q8 @ choose \n"
#define NCHWC1_TRANS_FP32_STORE \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \
"vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \
"vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \
"vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \
"subs %[cnt], %[cnt], #1 @ loop count - 1\n" \
\
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \
"vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \
"vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \
"vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \
"vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result, \n" \
\
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \
"vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \
\
"bne 1b @ jump to main loop\n"
#endif
......@@ -934,12 +934,12 @@ inline bool write_to_output_c1_fp32(const float* din,
"fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \
"fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC2_TRANS_FP32_LEAKY_RELU \
"cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \
"bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \
#define NCHWC2_TRANS_FP32_LEAKY_RELU \
"fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \
"bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \
"bif v3.16b, v5.16b, v7.16b \n" /* choose*/
#define NCHWC2_TRANS_FP32_STORE \
......@@ -1275,19 +1275,19 @@ inline bool write_to_output_c2_fp32(const float* din,
"fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */ \
"fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
#define NCHWC4_TRANS_FP32_LEAKY_RELU \
"cmhs v8.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v9.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v10.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v11.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
"fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
"fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \
"bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \
"bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \
"bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \
"bif v19.16b, v7.16b, v11.16b \n" /* choose*/
#define NCHWC4_TRANS_FP32_LEAKY_RELU \
"fcmge v8.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v9.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v10.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \
"fcmge v11.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
"fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
"fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \
"bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \
"bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \
"bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \
"bif v19.16b, v7.16b, v11.16b \n" /* choose*/
#define NCHWC4_TRANS_FP32_STORE \
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
......@@ -1754,15 +1754,15 @@ inline bool write_to_output_c4_fp32(const float* din,
"fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/
#define NCHWC8_TRANS_FP32_LEAKY_RELU \
"cmhs v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
\
"cmhs v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \
"fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \
\
"fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \
......@@ -1839,7 +1839,7 @@ inline bool write_to_output_c4_fp32(const float* din,
"vmin.f32 q7, q7, %q[six] @ relu6\n"
#define NCHWC8_TRANS_FP32_LEAKY_RELU \
"vcge.f32 q9, q0, q15 @ q0 > 0 \n" \
"vcge.f32 q9, q0, q15 @ q0 > 0 \n" \
"vcge.f32 q10, q1, q15 @ q0 > 0 \n" \
"vcge.f32 q11, q2, q15 @ q0 > 0 \n" \
"vcge.f32 q12, q3, q15 @ q0 > 0 \n" \
......@@ -2168,19 +2168,19 @@ inline void act_switch_c8_fp32(const float* din_ptr,
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define DO_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define DO_LEAKY_RELU \
"fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define DO_STORE \
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
......@@ -2217,7 +2217,7 @@ inline void act_switch_c8_fp32(const float* din_ptr,
"vbif q3, q8, q7 @ choose \n" \
"vbif q4, q10, q9 @ choose \n" \
"vbif q5, q12, q11 @ choose \n" \
"vbif q6, q13, q13 @ choose \n"
"vbif q6, q14, q13 @ choose \n"
#define DO_STORE \
"subs %[cnt], #1 \n" \
"vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
......
......@@ -123,20 +123,21 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
int padh,
ARMContext* ctx);
void conv_depthwise_5x5s1_fp32(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
void conv_depthwise_5x5s1_fp32(float* dout,
const float* din,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
const operators::ConvParam& param,
ARMContext* ctx);
void conv_depthwise_5x5s2_fp32(const float* din,
......
......@@ -188,7 +188,6 @@ void conv1x1s1_gemm(const float* i_data,
if (n > 1) {
weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
}
//! use gemv when the output channel size = 1
for (int b = 0; b < num; ++b) {
// dC
......@@ -210,8 +209,11 @@ void conv1x1s1_gemm(const float* i_data,
k,
flag_bias,
bias_group,
flag_relu,
ctx);
act_param.has_active,
act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else {
sgemm_prepack(false,
m,
......@@ -410,8 +412,11 @@ void conv_im2col_gemm(const float* i_data,
k,
flag_bias,
bias_group,
flag_relu,
ctx);
act_param.has_active,
act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else {
int ldb = n;
sgemm_prepack(false,
......@@ -677,7 +682,8 @@ void conv_depthwise_5x5_fp32(const void* din,
const float* scale) {
auto paddings = *param.paddings;
auto act_param = param.activation_param;
int pad = paddings[0];
int pad_h = paddings[0];
int pad_w = paddings[2];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
......@@ -698,20 +704,21 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param,
ctx);
} else if (stride == 1) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
bias,
pad,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
param,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv";
......
......@@ -136,19 +136,19 @@ void fill_bias_relu<int>(int* tensor,
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define FILL_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define FILL_LEAKY_RELU \
"fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define FILL_STORE \
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
......
......@@ -22,35 +22,87 @@ namespace lite {
namespace arm {
namespace math {
void sgemv(const bool transA,
const int M,
void sgemv(const int M,
const int N,
const float *A,
const float *x,
float *y);
void sgemv_relu(const bool transA,
const int M,
const int N,
const float *A,
const float *x,
float *y);
float *y,
bool flag_bias,
const float *bias);
void sgemv_bias(const bool transA,
const int M,
void sgemv_relu(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias);
void sgemv_bias_relu(const bool transA,
const int M,
const int N,
const float *A,
const float *x,
float *y,
const float *bias);
void sgemv_relu6(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
const float six);
void sgemv_leakey_relu(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
const float alpha);
void sgemv_trans(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
bool flag_act,
lite_api::ActivationType act,
const ARMContext *ctx,
float six,
float alpha);
bool sgemv(const float *A,
const float *x,
float *y,
bool transA,
int M,
int N,
bool is_bias,
const float *bias,
bool flag_act,
lite_api::ActivationType act,
const ARMContext *ctx,
float six,
float alpha) {
if (transA) {
sgemv_trans(M, N, A, x, y, is_bias, bias, flag_act, act, ctx, six, alpha);
} else {
if (flag_act) {
if (act == lite_api::ActivationType::kRelu) {
sgemv_relu(M, N, A, x, y, is_bias, bias);
} else if (act == lite_api::ActivationType::kRelu6) {
sgemv_relu6(M, N, A, x, y, is_bias, bias, six);
} else if (act == lite_api::ActivationType::kLeakyRelu) {
sgemv_leakey_relu(M, N, A, x, y, is_bias, bias, alpha);
} else {
LOG(FATAL)
<< "sgemv no transA only support relu, relu6, leakey relu fusion";
}
} else {
sgemv(M, N, A, x, y, is_bias, bias);
}
}
return true;
}
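The rewritten dispatcher above folds bias and activation handling into a single entry point, replacing the four sgemv/sgemv_bias/sgemv_relu/sgemv_bias_relu variants. A hypothetical call site for a bias-plus-leaky-relu fused GEMV (buffer names are illustrative; six is ignored for kLeakyRelu but still travels through the unified signature):

    // y = leaky_relu(A * x + bias), alpha = 0.1f, A is M x N row-major.
    paddle::lite::arm::math::sgemv(A, x, y,
                                   /*transA=*/false, M, N,
                                   /*is_bias=*/true, bias,
                                   /*flag_act=*/true,
                                   lite_api::ActivationType::kLeakyRelu,
                                   ctx, /*six=*/0.f, /*alpha=*/0.1f);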
#ifdef __aarch64__
void sgemv_trans(const int M,
const int N,
......@@ -59,8 +111,11 @@ void sgemv_trans(const int M,
float *y,
bool flag_bias,
const float *bias,
bool flag_relu,
const ARMContext *ctx) {
bool flag_act,
lite_api::ActivationType act,
const ARMContext *ctx,
float six,
float alpha) {
int m_cnt16 = M >> 4;
int m_cnt8 = (M & 15) >> 3;
int m_cnt4 = (M & 15 & 7) >> 2;
......@@ -281,26 +336,70 @@ void sgemv_trans(const int M,
valid_ths = rdc_ths;
rdc_ths = rdc_ths >> 1;
}
if (flag_relu) {
if (flag_act) {
float *in_y = y_buf;
float32x4_t vzero = vdupq_n_f32(0.f);
if (cnt4 > 0) {
int cnt = cnt4;
asm volatile(
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"1:\n"
"fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "v0", "v1", "cc", "memory");
}
for (int r = 0; r < remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
if (act == lite_api::ActivationType::kRelu) {
if (cnt4 > 0) {
int cnt = cnt4;
asm volatile(
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"1:\n"
"fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "v0", "v1", "cc", "memory");
}
for (int r = 0; r < remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
}
} else if (act == lite_api::ActivationType::kRelu6) {
float32x4_t vsix = vdupq_n_f32(six);
if (cnt4 > 0) {
int cnt = cnt4;
asm volatile(
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"1:\n"
"fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu6 */
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* v1 relu6 */
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "v0", "v1", "cc", "memory");
}
for (int r = 0; r < remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
y[r] = y[r] > six ? six : y[r];
}
} else if (act == lite_api::ActivationType::kLeakyRelu) {
float32x4_t valpha = vdupq_n_f32(alpha);
if (cnt4 > 0) {
int cnt = cnt4;
asm volatile(
"1:\n"
"ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */
"fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */
"fmul v5.4s, v0.4s, %[valpha].4s \n" /* vmulq_f32 */
"bif v0.16b, v5.16b, v4.16b \n" /* choose */
"subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */
"st1 {v0.4s}, [%[out_y]], #16 \n" /* store v0 to y */
"bne 1b \n" /* branch to label 1*/
: [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero), [valpha] "w"(valpha)
: "v0", "v4", "v5", "cc", "memory");
}
for (int r = 0; r < remain; ++r) {
y[r] = in_y[r] < 0.f ? alpha * in_y[r] : in_y[r];
}
}
} else {
memcpy(y, y_buf, M * sizeof(float));
......@@ -314,8 +413,11 @@ void sgemv_trans(const int M,
float *y,
bool flag_bias,
const float *bias,
bool flag_relu,
const ARMContext *ctx) {
bool flag_act,
lite_api::ActivationType act,
const ARMContext *ctx,
float six,
float alpha) {
int m_cnt8 = M >> 3;
int m_cnt4 = (M & 7) >> 2;
int m_remain = M & 7 & 3;
......@@ -497,43 +599,73 @@ void sgemv_trans(const int M,
valid_ths = rdc_ths;
rdc_ths = rdc_ths >> 1;
}
if (flag_relu) {
// do activation
if (flag_act) {
float *in_y = y_buf;
float32x4_t vzero = vdupq_n_f32(0.f);
if (m_cnt8 > 0) {
int cnt8 = m_cnt8;
asm volatile(
"vld1.32 {d0-d3}, [%[in_y]]! \n" /* load y to q0, q1 */
"1:\n"
"vmax.f32 q2, q0, %q[vzero] \n" /* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"vmax.f32 q3, q1, %q[vzero] \n" /* q1 relu */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d4-d7}, [%[out_y]]! \n" /* store q0, q1 to y*/
"vld1.32 {d2-d3}, [%[in_y]]! \n" /* load y to q0 */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #32 \n" /* restore in_y */
: [cnt] "+r"(cnt8), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "q0", "q1", "q2", "q3", "cc", "memory");
}
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"1:\n"
"vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "q0", "q1", "cc", "memory");
}
for (int r = 0; r < m_remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
m_cnt4 = M >> 2;
m_remain = M & 3;
if (act == lite_api::ActivationType::kRelu) {
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"1:\n"
"vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero)
: "q0", "q1", "cc", "memory");
}
for (int r = 0; r < m_remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
}
} else if (act == lite_api::ActivationType::kRelu6) {
float32x4_t vsix = vdupq_n_f32(six);
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"1:\n"
"vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu6 */
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"vmin.f32 q1, q1, %q[vsix] \n" /* q0 relu6 */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */
"bne 1b \n" /* branch to label 1*/
"sub %[in_y], %[in_y], #16 \n" /* restore in_y */
: [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "q0", "q1", "cc", "memory");
}
for (int r = 0; r < m_remain; ++r) {
y[r] = in_y[r] > 0.f ? in_y[r] : 0.f;
y[r] = y[r] > six ? six : y[r];
}
} else if (act == lite_api::ActivationType::kLeakyRelu) {
float32x4_t valpha = vdupq_n_f32(alpha);
if (m_cnt4 > 0) {
int cnt4 = m_cnt4;
asm volatile(
"1:\n"
"vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */
"vcge.f32 q3, q0, %q[vzero] \n" /* vcgeq_f32 */
"vmul.f32 q4, q0, %q[valpha] \n" /* vmulq_f32 */
"vbif q0, q4, q3 \n" /* choose */
"subs %[cnt], %[cnt], #1 \n" /* sub cnt */
"vst1.32 {d0-d1}, [%[out_y]]! \n" /* store q0 to y */
"bne 1b \n" /* branch to label 1*/
: [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y)
: [vzero] "w"(vzero), [valpha] "w"(valpha)
: "q0", "q3", "q4", "cc", "memory");
}
for (int r = 0; r < m_remain; ++r) {
y[r] = in_y[r] < 0.f ? alpha * in_y[r] : in_y[r];
}
}
} else {
memcpy(y, y_buf, M * sizeof(float));
......@@ -541,41 +673,6 @@ void sgemv_trans(const int M,
}
#endif // __aarch64__
bool sgemv(const float *A,
const float *x,
float *y,
bool transA,
int M,
int N,
bool is_bias,
const float *bias,
bool is_relu,
const ARMContext *ctx) {
if (transA) {
sgemv_trans(M, N, A, x, y, is_bias, bias, is_relu, ctx);
} else {
if (is_bias) {
//! with bias
if (is_relu) {
//! with relu
sgemv_bias_relu(transA, M, N, A, x, y, bias);
} else {
//! without relu
sgemv_bias(transA, M, N, A, x, y, bias);
}
} else {
//! without bias
if (is_relu) {
//! with relu
sgemv_relu(transA, M, N, A, x, y);
} else {
//! without relu
sgemv(transA, M, N, A, x, y);
}
}
}
return true;
}
// clang-format off
//! define compute kernel
#ifdef __aarch64__
......@@ -715,19 +812,19 @@ bool sgemv(const float *A,
#define SGEMV_KERNEL_1 \
/* check main loop */ \
"cmp %w[cnt], #1 \n" /* check whether has main loop */ \
"blt 2f \n" /* jump to tail */ /* main loop */ \
"1: \n" /* main loop */ \
"ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \
"ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \
"fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \
"subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \
"fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \
"blt 2f \n" /* jump to tail */ \
"1: \n" /* main loop */ \
"ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \
"ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \
"fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \
"subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \
"fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \
"bne 1b \n" /* jump to main loop */ \
/* pair add to final result */ \
"2: \n" /* reduce to scale */ \
"fadd v9.4s, v0.4s, v1.4s \n" /* add 2 vector */ \
"faddp v10.4s, v9.4s, v9.4s\n" /* pair add to vector */ \
"faddp s8, v10.2s \n" /* pair add to scale */ /* check tails */ \
"faddp s8, v10.2s \n" /* pair add to scale */ \
"cmp %w[tail], #1 \n" /* check whether has tail */ \
"blt 4f \n" /* jump to end */ \
"3: \n" /* tail loop */ \
......@@ -737,43 +834,100 @@ bool sgemv(const float *A,
"subs %w[tail], %w[tail], #1\n" /* sub tail loop count */ \
"bne 3b \n" /* jump to tail loop */
#define SGEMV_OUT_8 \
/* end */ \
"4: \n" /* end */ \
"stp s8, s9, [%[out]] \n" /* save result */ \
"stp s10, s11, [%[out], #8] \n" /* save result */ \
"stp s12, s13, [%[out], #16]\n" /* save result */ \
"stp s14, s15, [%[out], #24]\n" /* save result */
#define SGEMV_OUT_8 \
/* end */ \
"4: \n" /* end */ \
"mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \
"mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \
"mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \
"mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \
"mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \
"mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \
"mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \
"stp q8, q9, [%[out]] \n" /* save result */
#define SGEMV_OUT_8_RELU \
/* end */ \
"4: \n" /* end */ \
"movi d0, #0 \n" /* zero data for relu */ \
"fmax s8, s8, s0 \n" /* relu */ \
"fmax s9, s9, s0 \n" /* relu */ \
"fmax s10, s10, s0 \n" /* relu */ \
"fmax s11, s11, s0 \n" /* relu */ \
"fmax s12, s12, s0 \n" /* relu */ \
"fmax s13, s13, s0 \n" /* relu */ \
"fmax s14, s14, s0 \n" /* relu */ \
"fmax s15, s15, s0 \n" /* relu */ \
"stp s8, s9, [%[out]] \n" /* save result */ \
"stp s10, s11, [%[out], #8] \n" /* save result */ \
"stp s12, s13, [%[out], #16]\n" /* save result */ \
"stp s14, s15, [%[out], #24]\n" /* save result */
"4: \n" /* end */ \
"mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \
"mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \
"mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \
"mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \
"mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \
"mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \
"mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \
"movi v2.4s, #0 \n" /* zero data for relu */\
"fmax v8.4s, v8.4s, v2.4s \n" /* relu */ \
"fmax v9.4s, v9.4s, v2.4s \n" /* relu */ \
"stp q8, q9, [%[out]] \n" /* save result */
#define SGEMV_OUT_1 \
/* end */ \
"4: \n" /* end */ \
#define SGEMV_OUT_8_RELU6 \
/* end */ \
"4: \n" /* end */ \
"mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \
"mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \
"mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \
"mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \
"mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \
"mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \
"mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \
"movi v2.4s, #0 \n" /* zero data for relu6 */\
"fmax v8.4s, v8.4s, v2.4s \n" /* relu6 */ \
"fmax v9.4s, v9.4s, v2.4s \n" /* relu6 */ \
"fmin v8.4s, v8.4s, %[vsix].4s \n" /* relu */ \
"fmin v9.4s, v9.4s, %[vsix].4s \n" /* relu */ \
"stp q8, q9, [%[out]] \n" /* save result */
#define SGEMV_OUT_8_LEAKEY_RELU \
/* end */ \
"4: \n" /* end */ \
"mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \
"mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \
"mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \
"mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \
"mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \
"mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \
"mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \
"movi v2.4s, #0 \n" /* zero data for leakey relu */ \
"fcmge v4.4s, v8.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v8.4s, %[valpha].4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v9.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v9.4s, %[valpha].4s \n" /* vmulq_f32 */ \
"bif v8.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v9.16b, v7.16b, v6.16b \n" /* choose*/ \
"stp q8, q9, [%[out]] \n" /* save result */
#define SGEMV_OUT_1 \
/* end */ \
"4: \n" /* end */ \
"str s8, [%[out]] \n" /* save result */
#define SGEMV_OUT_1_RELU \
/* end */ \
"4: \n" /* end */ \
"movi d0, #0 \n" /* zero data for relu */ \
"fmax s8, s8, s0 \n" /* relu */ \
"movi d1, #0 \n" /* zero data for relu */ \
"fmax s8, s8, s1 \n" /* relu */ \
"str s8, [%[out]] \n" /* save result */
#define SGEMV_OUT_1_RELU6 \
/* end */ \
"4: \n" /* end */ \
"movi d1, #0 \n" /* zero data for relu6 */ \
"fmov s2, %w[six] \n" /* mov six to s2 */ \
"fmax s8, s8, s1 \n" /* relu6 */ \
"fmin s8, s8, s2 \n" /* relu6 */ \
"str s8, [%[out]] \n" /* save result */
#define SGEMV_OUT_1_LEAKEY_RELU \
/* end */ \
"4: \n" /* end */ \
"fmov s1, %w[alpha] \n" /* mov alpha to s1 */ \
"fcmp s8, #0 \n" /* cmp with zero*/ \
"bge 5f \n" /* if ge zero */ \
"fmul s8, s8, s1 \n" /* out * alpha */ \
"5: \n" /* leakey relu label */ \
"str s8, [%[out]] \n" /* save result */
#else // __aarch64__
#define SGEMV_IN_4 \
......@@ -841,14 +995,13 @@ bool sgemv(const float *A,
"vmla.f32 q2, q5, q11 @ mul add\n" \
"vmla.f32 q3, q5, q13 @ mul add\n" \
"bne 1b @ jump to main loop\n" \
/* pair add to final result */ \
"2: @ pair add \n" \
"vpadd.f32 d8, d0, d1 @ pair add, first step\n" \
"vpadd.f32 d9, d2, d3 @ pair add, first step\n" \
"vpadd.f32 d10, d4, d5 @ pair add, first step\n" \
"vpadd.f32 d11, d6, d7 @ pair add, first step\n" \
"vpadd.f32 d0, d8, d9 @ pair add, second step\n" \
"vpadd.f32 d1, d10, d11 @ pair add, second step\n" /* check tails */ \
"vpadd.f32 d1, d10, d11 @ pair add, second step\n" \
"cmp %[tail], #1 @ check whether has tail\n" \
"blt 4f @ jump to end\n" \
"3: @ tail loop\n" \
......@@ -876,7 +1029,7 @@ bool sgemv(const float *A,
"bne 1b @ jump to main loop\n" \
"2: @ end processing\n" \
"vpadd.f32 d2, d0, d1 @ pair add, first step\n" \
"vpadd.f32 d0, d2, d2 @ pair add, final step\n"/*check tails*/ \
"vpadd.f32 d0, d2, d2 @ pair add, final step\n" \
"cmp %[tail], #1 @ check whether has mid cols\n" \
"blt 4f @ jump to end\n" \
"3: @ tail loop\n" \
......@@ -898,6 +1051,25 @@ bool sgemv(const float *A,
"vmax.f32 q0, q0, q1 @ relu\n" \
"vst1.32 {d0-d1}, [%[out]] @ save result\n"
#define SGEMV_OUT_4_RELU6 \
/* end */ \
"4: @ end\n" \
"vmov.i32 q1, #0 @ zero for relu6\n" \
"vdup.f32 q2, %[six] @ six for relu6\n" \
"vmax.f32 q0, q0, q1 @ relu6\n" \
"vmin.f32 q0, q0, q2 @ relu6\n" \
"vst1.32 {d0-d1}, [%[out]] @ save result\n"
#define SGEMV_OUT_4_LEAKEY_RELU \
/* end */ \
"4: @ end\n" \
"vmov.i32 q1, #0 @ zero for leakey relu\n" \
"vdup.f32 q2, %[alpha] @ alpha for leakey relu\n" \
"vcge.f32 q3, q0, q1 @ vcgeq_f32 \n" \
"vmul.f32 q4, q0, q2 @ vmulq_f32 \n" \
"vbif q0, q4, q3 @ choose \n" \
"vst1.32 {d0-d1}, [%[out]] @ save result\n"
#define SGEMV_OUT_1 \
/* end */ \
"4: @ end\n" \
......@@ -909,14 +1081,36 @@ bool sgemv(const float *A,
"vmov.i32 d1, #0 @ zero for relu\n" \
"vmax.f32 d0, d0, d1 @ relu\n" \
"vst1.32 {d0[0]}, [%[out]] @ save result\n"
#define SGEMV_OUT_1_RELU6 \
/* end */ \
"4: @ end\n" \
"vmov.i32 d1, #0 @ zero for relu6\n" \
"vdup.f32 d4, %[six] @ six for relu6\n" \
"vmax.f32 d0, d0, d1 @ relu6\n" \
"vmin.f32 d0, d0, d4 @ relu6\n" \
"vst1.32 {d0[0]}, [%[out]] @ save result\n"
#define SGEMV_OUT_1_LEAKEY_RELU \
/* end */ \
"4: @ end\n" \
"vmov.i32 d2, #0 @ zero for leakey relu\n" \
"vdup.f32 d3, %[alpha] @ alpha for leakey relu\n" \
"vcge.f32 d6, d0, d2 @ vcgeq_f32 \n" \
"vmul.f32 d8, d0, d3 @ vmulq_f32 \n" \
"vbif d0, d8, d6 @ choose \n" \
"vst1.32 {d0[0]}, [%[out]] @ save result\n"
#endif
// clang-format on
void sgemv(const bool transA,
const int M,
void sgemv(const int M,
const int N,
const float *A,
const float *x,
float *y) {
float *y,
bool flag_bias,
const float *bias) {
float *data_out = y;
const float *data_in = x;
const float *weights_ptr = A;
......@@ -926,7 +1120,6 @@ void sgemv(const bool transA,
#ifdef __aarch64__
int out_cnt = M >> 3;
#pragma omp parallel for
for (int j = 0; j < out_cnt; j++) {
int out_idx = j * 8;
......@@ -940,9 +1133,22 @@ void sgemv(const bool transA,
const float *ptr_w5 = ptr_w4 + N;
const float *ptr_w6 = ptr_w5 + N;
const float *ptr_w7 = ptr_w6 + N;
const float *bias_ptr = bias + out_idx;
float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
if (flag_bias) {
bias_local[0] = bias_ptr[0];
bias_local[1] = bias_ptr[1];
bias_local[2] = bias_ptr[2];
bias_local[3] = bias_ptr[3];
bias_local[4] = bias_ptr[4];
bias_local[5] = bias_ptr[5];
bias_local[6] = bias_ptr[6];
bias_local[7] = bias_ptr[7];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8
// clang-format off
asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -954,35 +1160,12 @@ void sgemv(const bool transA,
[w7] "+r"(ptr_w7),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"cc",
"memory");
: [out] "r"(ptr_out), [bias_ptr] "r"(bias_local)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "cc", "memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -992,24 +1175,17 @@ void sgemv(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
float tmp[4];
float tmp1[4];
float tmp2[4];
float tmp3[4];
float tmp4[4];
asm volatile(
SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out),
[tmp] "r"(tmp),
[tmp1] "r"(tmp1),
[tmp2] "r"(tmp2),
[tmp3] "r"(tmp3),
[tmp4] "r"(tmp4)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc");
}
#else // __aarch64__
int out_cnt = M >> 2;
......@@ -1022,10 +1198,20 @@ void sgemv(const bool transA,
const float *ptr_w1 = ptr_w0 + N;
const float *ptr_w2 = ptr_w1 + N;
const float *ptr_w3 = ptr_w2 + N;
float bias0 = 0.f;
float bias1 = 0.f;
float bias2 = 0.f;
float bias3 = 0.f;
if (flag_bias) {
bias0 = bias[out_idx];
bias1 = bias[out_idx + 1];
bias2 = bias[out_idx + 2];
bias3 = bias[out_idx + 3];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4
// clang-format off
asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1033,23 +1219,16 @@ void sgemv(const bool transA,
[w3] "+r"(ptr_w3),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"cc",
: [out] "r"(ptr_out),
[bias0] "r"(bias0),
[bias1] "r"(bias1),
[bias2] "r"(bias2),
[bias3] "r"(bias3)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "cc",
"memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1059,23 +1238,28 @@ void sgemv(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif // __aarch64__
}
void sgemv_relu(const bool transA,
const int M,
void sgemv_relu(const int M,
const int N,
const float *A,
const float *x,
float *y) {
float *y,
bool flag_bias,
const float *bias) {
float *data_out = y;
const float *data_in = x;
const float *weights_ptr = A;
......@@ -1098,9 +1282,22 @@ void sgemv_relu(const bool transA,
const float *ptr_w5 = ptr_w4 + N;
const float *ptr_w6 = ptr_w5 + N;
const float *ptr_w7 = ptr_w6 + N;
const float *bias_ptr = bias + out_idx;
float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
if (flag_bias) {
bias_local[0] = bias_ptr[0];
bias_local[1] = bias_ptr[1];
bias_local[2] = bias_ptr[2];
bias_local[3] = bias_ptr[3];
bias_local[4] = bias_ptr[4];
bias_local[5] = bias_ptr[5];
bias_local[6] = bias_ptr[6];
bias_local[7] = bias_ptr[7];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8_RELU
// clang-format off
asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1112,35 +1309,12 @@ void sgemv_relu(const bool transA,
[w7] "+r"(ptr_w7),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"cc",
"memory");
: [out] "r"(ptr_out), [bias_ptr] "r"(bias_local)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "cc", "memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1150,13 +1324,17 @@ void sgemv_relu(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(
SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU
SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else // __aarch64__
......@@ -1170,10 +1348,20 @@ void sgemv_relu(const bool transA,
const float *ptr_w1 = ptr_w0 + N;
const float *ptr_w2 = ptr_w1 + N;
const float *ptr_w3 = ptr_w2 + N;
float bias0 = 0.f;
float bias1 = 0.f;
float bias2 = 0.f;
float bias3 = 0.f;
if (flag_bias) {
bias0 = bias[out_idx];
bias1 = bias[out_idx + 1];
bias2 = bias[out_idx + 2];
bias3 = bias[out_idx + 3];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4_RELU
// clang-format off
asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1181,23 +1369,16 @@ void sgemv_relu(const bool transA,
[w3] "+r"(ptr_w3),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"cc",
: [out] "r"(ptr_out),
[bias0] "r"(bias0),
[bias1] "r"(bias1),
[bias2] "r"(bias2),
[bias3] "r"(bias3)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "cc",
"memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1207,31 +1388,36 @@ void sgemv_relu(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif // __aarch64__
}
void sgemv_bias(const bool transA,
const int M,
const int N,
const float *A,
const float *x,
float *y,
const float *bias) {
void sgemv_relu6(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
const float six) {
float *data_out = y;
const float *data_in = x;
const float *weights_ptr = A;
int cnt = N >> 3;
int tail = N & 7;
float32x4_t vsix = vdupq_n_f32(six);
#ifdef __aarch64__
int out_cnt = M >> 3;
#pragma omp parallel for
......@@ -1248,9 +1434,21 @@ void sgemv_bias(const bool transA,
const float *ptr_w6 = ptr_w5 + N;
const float *ptr_w7 = ptr_w6 + N;
const float *bias_ptr = bias + out_idx;
float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
if (flag_bias) {
bias_local[0] = bias_ptr[0];
bias_local[1] = bias_ptr[1];
bias_local[2] = bias_ptr[2];
bias_local[3] = bias_ptr[3];
bias_local[4] = bias_ptr[4];
bias_local[5] = bias_ptr[5];
bias_local[6] = bias_ptr[6];
bias_local[7] = bias_ptr[7];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8
// clang-format off
asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU6
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1262,35 +1460,13 @@ void sgemv_bias(const bool transA,
[w7] "+r"(ptr_w7),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"cc",
"memory");
: [out] "r"(ptr_out), [bias_ptr] "r"(bias_local),
[vsix] "w" (vsix)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "cc", "memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1300,14 +1476,17 @@ void sgemv_bias(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
float bias0 = bias[j];
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(
SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1
SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU6
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: [out] "r"(ptr_out), [bias0] "r"(bias0), [six] "r"(six)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else // __aarch64__
......@@ -1321,14 +1500,20 @@ void sgemv_bias(const bool transA,
const float *ptr_w1 = ptr_w0 + N;
const float *ptr_w2 = ptr_w1 + N;
const float *ptr_w3 = ptr_w2 + N;
float bias0 = bias[out_idx];
float bias1 = bias[out_idx + 1];
float bias2 = bias[out_idx + 2];
float bias3 = bias[out_idx + 3];
float bias0 = 0.f;
float bias1 = 0.f;
float bias2 = 0.f;
float bias3 = 0.f;
if (flag_bias) {
bias0 = bias[out_idx];
bias1 = bias[out_idx + 1];
bias2 = bias[out_idx + 2];
bias3 = bias[out_idx + 3];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4
// clang-format off
asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU6
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1340,23 +1525,13 @@ void sgemv_bias(const bool transA,
[bias0] "r"(bias0),
[bias1] "r"(bias1),
[bias2] "r"(bias2),
[bias3] "r"(bias3)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"cc",
[bias3] "r"(bias3),
[six] "r" (six)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "cc",
"memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1366,30 +1541,35 @@ void sgemv_bias(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
float bias0 = bias[j];
asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU6
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: [out] "r"(ptr_out), [bias0] "r"(bias0), [six] "r"(six)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif // __aarch64__
}
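
The SGEMV_OUT_*_RELU6 macros referenced above presumably clamp the biased accumulators to [0, six]; in NEON intrinsics the per-vector epilogue would look roughly like this (a sketch under that assumption, not the actual macro body):

#include <arm_neon.h>

// relu6 epilogue on four accumulators: out = min(max(acc + bias, 0), six).
static inline float32x4_t relu6_epilogue(float32x4_t acc,
                                         float32x4_t vbias,
                                         float32x4_t vsix) {
  float32x4_t v = vaddq_f32(acc, vbias);  // add bias
  v = vmaxq_f32(v, vdupq_n_f32(0.f));     // relu lower clamp
  return vminq_f32(v, vsix);              // clip at six
}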
void sgemv_bias_relu(const bool transA,
const int M,
const int N,
const float *A,
const float *x,
float *y,
const float *bias) {
void sgemv_leakey_relu(const int M,
const int N,
const float *A,
const float *x,
float *y,
bool flag_bias,
const float *bias,
const float alpha) {
float *data_out = y;
const float *data_in = x;
const float *weights_ptr = A;
int cnt = N >> 3;
int tail = N & 7;
float32x4_t valpha = vdupq_n_f32(alpha);
#ifdef __aarch64__
int out_cnt = M >> 3;
#pragma omp parallel for
......@@ -1406,9 +1586,21 @@ void sgemv_bias_relu(const bool transA,
const float *ptr_w6 = ptr_w5 + N;
const float *ptr_w7 = ptr_w6 + N;
const float *bias_ptr = bias + out_idx;
float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
if (flag_bias) {
bias_local[0] = bias_ptr[0];
bias_local[1] = bias_ptr[1];
bias_local[2] = bias_ptr[2];
bias_local[3] = bias_ptr[3];
bias_local[4] = bias_ptr[4];
bias_local[5] = bias_ptr[5];
bias_local[6] = bias_ptr[6];
bias_local[7] = bias_ptr[7];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU
// clang-format off
asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_LEAKEY_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1420,35 +1612,13 @@ void sgemv_bias_relu(const bool transA,
[w7] "+r"(ptr_w7),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25",
"cc",
"memory");
: [out] "r"(ptr_out), [bias_ptr] "r"(bias_local),
[valpha] "w" (valpha)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "cc", "memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1458,14 +1628,17 @@ void sgemv_bias_relu(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
float bias0 = bias[j];
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(
SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU
SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_LEAKEY_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: [out] "r"(ptr_out), [bias0] "r"(bias0), [alpha] "r"(alpha)
: "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory");
}
#else // __aarch64__
......@@ -1479,14 +1652,20 @@ void sgemv_bias_relu(const bool transA,
const float *ptr_w1 = ptr_w0 + N;
const float *ptr_w2 = ptr_w1 + N;
const float *ptr_w3 = ptr_w2 + N;
float bias0 = bias[out_idx];
float bias1 = bias[out_idx + 1];
float bias2 = bias[out_idx + 2];
float bias3 = bias[out_idx + 3];
float bias0 = 0.f;
float bias1 = 0.f;
float bias2 = 0.f;
float bias3 = 0.f;
if (flag_bias) {
bias0 = bias[out_idx];
bias1 = bias[out_idx + 1];
bias2 = bias[out_idx + 2];
bias3 = bias[out_idx + 3];
}
int cnt_loop = cnt;
int tail_loop = tail;
asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU
// clang-format off
asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_LEAKEY_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[w1] "+r"(ptr_w1),
......@@ -1498,23 +1677,13 @@ void sgemv_bias_relu(const bool transA,
[bias0] "r"(bias0),
[bias1] "r"(bias1),
[bias2] "r"(bias2),
[bias3] "r"(bias3)
: "q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"cc",
[bias3] "r"(bias3),
[alpha] "r" (alpha)
: "q0", "q1", "q2", "q3", "q4",
"q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12", "q13", "cc",
"memory");
// clang-format on
}
//! deal with remains
#pragma omp parallel for
......@@ -1524,14 +1693,18 @@ void sgemv_bias_relu(const bool transA,
const float *ptr_w0 = weights_ptr + (N * j);
int cnt_loop = cnt;
int tail_loop = tail;
float bias0 = bias[j];
asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias0] "r"(bias0)
: "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory");
float bias0 = 0.f;
if (flag_bias) {
bias0 = bias[j];
}
asm volatile(
SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_LEAKEY_RELU
: [in] "+r"(ptr_in),
[w0] "+r"(ptr_w0),
[cnt] "+r"(cnt_loop),
[tail] "+r"(tail_loop)
: [out] "r"(ptr_out), [bias0] "r"(bias0), [alpha] "r"(alpha)
: "q0", "q1", "q3", "q4", "q12", "q13", "q14", "q15", "cc", "memory");
}
#endif // __aarch64__
}
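
The LEAKEY_RELU output macros pair a floating-point compare with a bitwise select; the intrinsics equivalent of that pattern is roughly the following (an illustrative sketch, not the macro body):

#include <arm_neon.h>

// leaky relu epilogue: keep v where v >= 0, else alpha * v.
static inline float32x4_t leaky_relu_epilogue(float32x4_t v,
                                              float32x4_t valpha) {
  uint32x4_t ge_zero = vcgeq_f32(v, vdupq_n_f32(0.f));  // mask: v >= 0
  float32x4_t scaled = vmulq_f32(v, valpha);            // alpha * v
  return vbslq_f32(ge_zero, v, scaled);                 // bitwise select
}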
......
......@@ -17,23 +17,26 @@
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/device_info.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
// TODO(xxx): fix me, currently only transA = false is supported
bool sgemv(const float* A,
const float* x,
float* y,
bool sgemv(const float *A,
const float *x,
float *y,
bool transA,
int M,
int N,
bool is_bias,
const float* bias,
bool is_relu,
const ARMContext* ctx);
const float *bias,
bool flag_act,
lite_api::ActivationType act,
const ARMContext *ctx,
float six = 6.f,
float alpha = 1.f);
} // namespace math
} // namespace arm
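
With the widened declaration, a caller picks the fused activation through flag_act/act and passes six or alpha only when the chosen activation consumes it. A hedged usage sketch (A_data, x_data, y_data, bias_data, M, N, and ctx are placeholders, not names from this patch):

// y = relu6(A * x + bias), clipped at 6.
paddle::lite::arm::math::sgemv(A_data,     // weights, M x N, row major
                               x_data,     // input vector, length N
                               y_data,     // output vector, length M
                               false,      // transA: only false is supported
                               M,
                               N,
                               true,       // with bias
                               bias_data,
                               true,       // fuse an activation
                               paddle::lite_api::ActivationType::kRelu6,
                               &ctx,       // ARMContext*
                               6.f,        // six: relu6 clip value
                               1.f);       // alpha: ignored for relu6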
......
......@@ -42,8 +42,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int stride = param.strides[0];
int threads = ctx.threads();
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
int chin = param.x->dims()[1];
int hin = param.x->dims()[2];
int win = param.x->dims()[3];
......@@ -51,28 +49,28 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int hout = param.output->dims()[2];
int wout = param.output->dims()[3];
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]);
bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh);
bool ks_equal = (param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2));
bool flag_dw_5x5 = (paddings[0] == paddings[2]) &&
((kw == 5 && stride == 1) || (kw == 5 && stride == 2));
bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2);
bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
/// dw conv impl
if (param.groups == ic && ic == oc && ks_equal && no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
} else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation && pads_all_equal) {
/// winograd conv impl
// TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 &&
chin * chout < 4 * hin * win && kps_equal && no_dilation) {
/// direct conv impl
chin * chout < 4 * hin * win && ks_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking direct conv";
} else {
......@@ -109,7 +107,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && sw == 1);
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
......@@ -154,7 +152,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && sw == 1);
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
......
......@@ -52,7 +52,10 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
} else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32";
if (param.strides[0] == 2) { // conv5x5s2_dw
auto strides = param.strides;
if ((strides[0] == 1 && strides[1] == 1) ||
(strides[0] == 2 && strides[1] == 2)) {
// trans weights
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
......@@ -63,10 +66,11 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw);
flag_trans_weights_ = true;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else {
flag_trans_weights_ = false;
LOG(FATAL)
<< "5x5 depthwise conv only supports stride == 1 or stride == 2";
}
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
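
conv_trans_weights_numc is called here with cblock = 4; the assumed effect is a repack of the depthwise weights from [oc][kh*kw] into channel blocks of four, zero-padding the tail block. A hedged sketch of that layout (the real helper's layout may differ in detail):

// [oc][k] -> [ceil(oc/4)][k][4], tail channels padded with zeros.
static void repack_weights_c4(const float *src, float *dst, int oc, int k) {
  const int cblock = 4;
  int oc_pad = (oc + cblock - 1) / cblock * cblock;
  for (int c = 0; c < oc_pad; ++c) {
    for (int i = 0; i < k; ++i) {
      float v = (c < oc) ? src[c * k + i] : 0.f;  // pad tail with zero
      dst[(c / cblock) * k * cblock + i * cblock + c % cblock] = v;
    }
  }
}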
......
......@@ -93,9 +93,11 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
if (flag_trans_bias_) {
b_data = bias_.data<float>();
}
bool flag_relu = false;
bool flag_act = false;
lite_api::ActivationType act;
if (param.activation_type == "relu") {
flag_relu = true;
act = lite_api::ActivationType::kRelu;
flag_act = true;
}
if (flag_gemm_) {
operators::ActivationParam act_param;
......@@ -119,7 +121,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_act);
}
} else {
for (int i = 0; i < m_; ++i) {
......@@ -133,7 +135,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_,
param.bias != nullptr,
b_data,
flag_relu,
flag_act,
act,
&ctx);
}
}
......
......@@ -233,8 +233,17 @@ void MatMulCompute::Run() {
int ldb = n_;
int ldc = n_;
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx);
lite::arm::math::sgemv(x_data,
y_data,
o_data,
false,
m_,
k_,
false,
nullptr,
false,
lite_api::ActivationType::kIndentity,
&ctx);
if (fabsf(alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= alpha;
......
......@@ -50,8 +50,17 @@ void MulCompute::Run() {
k_ = x_w;
auto& ctx = this->ctx_->template As<ARMContext>();
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx);
lite::arm::math::sgemv(x_data,
y_data,
o_data,
false,
m_,
k_,
false,
nullptr,
false,
lite_api::ActivationType::kIndentity,
&ctx);
} else {
constexpr bool is_tranposed_y = false;
......
......@@ -131,7 +131,7 @@ class FcOPTest : public arena::TestCase {
1.f,
0.f,
true,
flag_bias,
static_cast<int>(flag_bias),
false);
} else {
basic_gemm(false,
......
......@@ -46,14 +46,19 @@ DEFINE_int32(out_channel, 32, "output channel");
DEFINE_int32(group, 1, "group");
DEFINE_int32(kernel_h, 3, "kernel height");
DEFINE_int32(kernel_w, 3, "kernel width");
DEFINE_int32(pad_h, 1, "pad height");
DEFINE_int32(pad_w, 1, "pad width");
DEFINE_int32(pad_h0, 1, "pad top");
DEFINE_int32(pad_h1, 1, "pad bottom");
DEFINE_int32(pad_w0, 1, "pad left");
DEFINE_int32(pad_w1, 1, "pad right");
DEFINE_int32(stride_h, 1, "stride height");
DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_relu, true, "do relu");
DEFINE_int32(flag_act,
0,
"do activation"); // 0-no act, 1-relu, 2-relu6, 4-leakyrelu
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim;
......@@ -98,9 +103,10 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
int flag_act,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {
const std::vector<int>& power_mode,
const float leakey_relu_scale) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
......@@ -118,13 +124,20 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
param.strides = strides;
param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilas);
param.fuse_relu = flag_relu;
param.groups = group;
if (flag_relu) {
const float six = 6.f;
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type =
(paddle::lite_api::ActivationType)1; // 2-relu6 4-leakyrelu
act_param.active_type = (paddle::lite_api::ActivationType)
flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.activation_param = act_param;
}
......@@ -205,7 +218,9 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
pads[2],
pads[0],
flag_bias,
flag_relu);
flag_act,
six,
leakey_relu_scale);
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
......@@ -254,22 +269,20 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " failed!!\n";
}
}
}
LOG(INFO) << "test fp32 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " successed!!\n";
}
}
}
......@@ -287,12 +300,14 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
int flag_act,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {}
const std::vector<int>& power_mode,
const float leakey_relu_scale) {}
#endif // LITE_WITH_ARM
#if 1 /// 3x3dw
// TODO(chenjiaoAngel): fix me, diff: 3x3 depthwise conv
#if 0 /// 3x3dw
TEST(TestConv3x3DW, test_conv3x3_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
......@@ -301,7 +316,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3});
......@@ -310,6 +325,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
dims.push_back(DDim({batch, c, h, h}));
}
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims,
weights_dim,
c,
......@@ -317,9 +333,10 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{1, 2, 4},
{FLAGS_power_mode});
{FLAGS_power_mode},
leakey_relu_scale);
}
}
}
......@@ -335,28 +352,41 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
#if 1 /// 5x5dw
TEST(TestConv5x5DW, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
#ifdef __aarch64__
// TODO(chenjiaoAngel): fix me, diff: arm64 5x5s2 depthwise conv
for (auto& stride : {1}) {
#else
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
#endif
for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 15, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 56}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims,
weights_dim,
c,
{stride, stride},
{pad_left, pad_right, pad_top, pad_bottom},
{1, 1},
flag_bias,
flag_act,
{4},
{FLAGS_power_mode},
leakey_relu_scale);
}
}
}
test_conv_fp32(dims,
weights_dim,
c,
{stride, stride},
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_power_mode});
}
}
}
......@@ -373,7 +403,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
for (auto& cout : {1, 5, 16, 37}) {
for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) {
continue;
......@@ -384,6 +414,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims,
weights_dim,
g,
......@@ -391,9 +422,10 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
{0, 0, 0, 0},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{1, 2, 4},
{FLAGS_power_mode});
{FLAGS_power_mode},
leakey_relu_scale);
}
}
}
......@@ -403,24 +435,29 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
}
#endif /// conv1x1s1
#if 1 /// conv3x3s1
// TODO(MyPandaShaoxiang): fix me, diff: 3x3s1 winograd
#if 0 /// conv3x3s1
TEST(TestConv3x3s1, test_conv_3x3s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32, 48}) {
for (auto& cout : {1, 5, 8, 32, 48}) {
for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) {
for (auto& pad_top : {1, 2}) {
for (auto& pad_bottom : {1, 2}) {
for (auto& cin : {1, 3, 8, 8}) {
for (auto& cout : {1, 5, 32, 48}) {
for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 56, 32}) {
for (auto& h : {1, 3, 17, 33}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
if (cin == 1 && cout ==1) {
continue;
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims,
weights_dim,
1,
......@@ -428,9 +465,10 @@ TEST(TestConv3x3s1, test_conv_3x3s1) {
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_power_mode});
flag_act,
{4},
{FLAGS_power_mode},
leakey_relu_scale);
}
}
}
......@@ -446,21 +484,25 @@ TEST(TestConv3x3s1, test_conv_3x3s1) {
#if 1 /// conv3x3s2
TEST(TestConv3x3s2, test_conv_3x3s2) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) {
for (auto& cout : {1, 5, 8, 32}) {
for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) {
for (auto& pad_top : {1, 2}) {
for (auto& pad_bottom : {1, 2}) {
for (auto& cin : {1, 3, 8}) {
for (auto& cout : {1, 3, 9, 32}) {
for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 75, 56, 32}) {
for (auto& h : {3, 7, 15, 56, 32}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
if (cin == 1 && cout == 1) {
continue;
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims,
weights_dim,
1,
......@@ -468,9 +510,10 @@ TEST(TestConv3x3s2, test_conv_3x3s2) {
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{1, 2, 4},
{FLAGS_power_mode});
{FLAGS_power_mode},
leakey_relu_scale);
}
}
}
......@@ -486,29 +529,40 @@ TEST(TestConv3x3s2, test_conv_3x3s2) {
#if 1 /// random param conv
TEST(TestConvRand, test_conv_rand) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 16}) {
for (auto& cout : {1, 5, 8, 16}) {
for (auto& cin : {1, 3, 8}) {
for (auto& cout : {1, 5, 16}) {
for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) {
for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) {
for (auto& pad_left : {0, 2}) {
for (auto& pad_right : {0, 2}) {
for (auto& pad_top : {0, 2}) {
for (auto& pad_bottom : {0, 2}) {
for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
if (cin % g != 0 || cout % g != 0) {
continue;
}
std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) {
for (auto& h : {1, 3, 19, 32}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
// skip 3x3 depthwise conv
if (g == cin && cin == cout && kw == 3 &&
kh == 3) {
break;
}
// skip 3x3s1 direct conv
if (g == 1 && (cin != 1 || cout != 1) &&
kw == 3 && kh == 3 && stride == 1) {
break;
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(
dims,
weights_dim,
......@@ -517,9 +571,10 @@ TEST(TestConvRand, test_conv_rand) {
{pad_top, pad_bottom, pad_left, pad_right},
{dila, dila},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_power_mode});
flag_act,
{4},
{FLAGS_power_mode},
leakey_relu_scale);
}
}
}
......@@ -551,11 +606,12 @@ TEST(TestConvCustom, test_conv_fp32_custom_size) {
FLAGS_kernel_w}),
FLAGS_group,
{FLAGS_stride_h, FLAGS_stride_w},
{FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w},
{FLAGS_pad_h0, FLAGS_pad_h1, FLAGS_pad_w0, FLAGS_pad_w1},
{FLAGS_dila_h, FLAGS_dila_w},
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_flag_act,
{FLAGS_threads},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_leakey_relu_alpha);
}
#endif // custom
......@@ -291,7 +291,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads[2],
pads[0],
flag_bias,
flag_relu);
static_cast<int>(flag_relu));
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32,
dout_basic_int8,
scale_out.data(),
......@@ -362,6 +362,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
......@@ -467,7 +468,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 75, 32, 28}) {
for (auto& h : {1, 3, 15, 33}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
......@@ -479,7 +480,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{4},
{FLAGS_power_mode});
}
}
......@@ -494,14 +495,14 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1}) {
for (auto& pad : {0, 1, 2}) {
for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
for (auto& h : {1, 3, 15, 33}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
......@@ -513,7 +514,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{4},
{FLAGS_power_mode});
}
}
......@@ -527,8 +528,8 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
#if 1 /// conv1x1s1
TEST(TestConv1x1s1Int8, test_conv1x1s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 11, 32}) {
for (auto& cout : {1, 5, 16, 37}) {
for (auto& cin : {1, 3, 8, 32}) {
for (auto& cout : {1, 5, 17}) {
for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
......@@ -538,7 +539,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
}
DDim weights_dim({cout, cin / g, 1, 1});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 32, 56, 1}) {
for (auto& h : {1, 9, 16, 33}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
......@@ -550,7 +551,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{4},
{FLAGS_power_mode});
}
}
......@@ -564,8 +565,8 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
#if 1 /// conv3x3s1
TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32, 48}) {
for (auto& cout : {1, 5, 8, 32, 48}) {
for (auto& cin : {1, 3, 8, 33}) {
for (auto& cout : {1, 5, 33}) {
for (auto& pad_top : {1, 2}) {
for (auto& pad_bottom : {1, 2}) {
for (auto& pad_left : {1, 2}) {
......@@ -575,7 +576,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 56, 32}) {
for (auto& h : {1, 7, 17, 33}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
......@@ -587,7 +588,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{4},
{FLAGS_power_mode});
}
}
......@@ -604,8 +605,8 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
#if 1 /// conv3x3s2
TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) {
for (auto& cout : {1, 5, 8, 32}) {
for (auto& cin : {1, 3, 31}) {
for (auto& cout : {1, 5, 33}) {
for (auto& pad_top : {1, 2}) {
for (auto& pad_bottom : {1, 2}) {
for (auto& pad_left : {1, 2}) {
......@@ -615,7 +616,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 75, 56, 32}) {
for (auto& h : {1, 7, 19, 33}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
......@@ -627,7 +628,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
{1, 1},
flag_bias,
flag_relu,
{1, 2, 4},
{4},
{FLAGS_power_mode});
}
}
......@@ -644,8 +645,8 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
#if 1 /// random param conv
TEST(TestConvRandInt8, test_conv_rand) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 16}) {
for (auto& cout : {1, 5, 8, 16}) {
for (auto& cin : {1, 17}) {
for (auto& cout : {1, 8, 17}) {
for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) {
......@@ -658,12 +659,12 @@ TEST(TestConvRandInt8, test_conv_rand) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
if (cin % g != 0 || cout % g != 0) {
continue;
break;
}
std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) {
for (auto& h : {1, 3, 5, 19}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
......@@ -676,7 +677,7 @@ TEST(TestConvRandInt8, test_conv_rand) {
{dila, dila},
flag_bias,
flag_relu,
{1, 2, 4},
{4},
{FLAGS_power_mode});
}
}
......
......@@ -37,7 +37,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemv: M");
......
......@@ -37,7 +37,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm_c4: M");
......
......@@ -38,11 +38,19 @@ DEFINE_int32(K, 512, "sgemv: K");
DEFINE_bool(traA, false, "gemv: A transpose");
DEFINE_bool(flag_relu, false, "do relu");
DEFINE_int32(flag_act, 0, "activation type: 0-no act, 1-relu, 2-relu6, 4-leakyrelu");
DEFINE_bool(flag_bias, false, "with bias");
bool test_sgemv(
bool tra, int m, int k, bool has_bias, bool has_relu, int cls, int ths) {
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_double(clipped_coef, 6.0, "clipped relu coef");
bool test_sgemv(bool tra,
int m,
int k,
bool has_bias,
int flag_act,
int cls,
int ths,
float six = 6.f,
float alpha = 1.f) {
Tensor ta;
Tensor tb;
Tensor tc;
......@@ -68,8 +76,7 @@ bool test_sgemv(
fill_tensor_rand(tbias, -1.f, 1.f);
LOG(INFO) << "sgemv M: " << m << ", K: " << k
<< ", transA: " << (tra ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
......@@ -78,10 +85,29 @@ bool test_sgemv(
auto dc = tc.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>();
auto dbias = tbias.mutable_data<float>();
paddle::lite_api::ActivationType act =
paddle::lite_api::ActivationType::kIndentity;
if (flag_act == 1) {
act = paddle::lite_api::ActivationType::kRelu;
} else if (flag_act == 2) {
act = paddle::lite_api::ActivationType::kRelu6;
} else if (flag_act == 4) {
act = paddle::lite_api::ActivationType::kLeakyRelu;
}
if (FLAGS_check_result) {
basic_gemv(
m, k, da, db, dbias, dc_basic, 1.f, 0.f, tra, has_bias, has_relu);
basic_gemv(m,
k,
da,
db,
dbias,
dc_basic,
1.f,
0.f,
tra,
has_bias,
flag_act,
six,
alpha);
}
paddle::lite::profile::Timer t0;
//! compute
......@@ -92,15 +118,37 @@ bool test_sgemv(
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
/// warmup
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemv(
da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx);
paddle::lite::arm::math::sgemv(da,
db,
dc,
tra,
m,
k,
has_bias,
dbias,
flag_act > 0,
act,
&ctx,
six,
alpha);
}
t0.Reset();
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
paddle::lite::arm::math::sgemv(
da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx);
paddle::lite::arm::math::sgemv(da,
db,
dc,
tra,
m,
k,
has_bias,
dbias,
flag_act > 0,
act,
&ctx,
six,
alpha);
t0.Stop();
}
LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls
......@@ -125,7 +173,7 @@ bool test_sgemv(
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "saber result: ";
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
......@@ -144,22 +192,31 @@ TEST(TestLiteSgemv, Sgemv) {
LOG(INFO) << "run basic sgemv test";
for (auto& m : {1, 3, 8, 21, 32, 397}) {
for (auto& k : {1, 3, 8, 17, 59, 234}) {
for (auto& tra : {true, false}) {
for (auto& tra : {false, true}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_sgemv(
tra, m, k, has_bias, has_relu, FLAGS_cluster, th);
float six = 6.f;
float alpha = 8.88f;
auto flag = test_sgemv(tra,
m,
k,
has_bias,
flag_act,
FLAGS_cluster,
th,
six,
alpha);
if (flag) {
LOG(INFO) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", flag act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", flag_act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " failed\n";
}
......@@ -180,15 +237,17 @@ TEST(TestSgemvCustom, Sgemv_custom) {
FLAGS_M,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_flag_act,
FLAGS_cluster,
FLAGS_threads);
FLAGS_threads,
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
<< ", act: " << FLAGS_flag_act << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " passed!!";
<< ", act: " << FLAGS_flag_act << " passed!!";
}
......@@ -177,7 +177,9 @@ static void basic_gemv(int m,
type2 beta,
bool trans_a = false,
bool flag_bias = false,
bool flag_relu = false) {
int flag_act = 0,
float six = 6.f,
float leakey_relu_alpha = 1.f) {
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0);
......@@ -195,8 +197,15 @@ static void basic_gemv(int m,
sum += av * b[j];
}
type2 tmp = alpha * sum + beta * c[i] + bias_data;
if (flag_relu) {
c[i] = tmp > (type2)0 ? tmp : (type2)0;
if (flag_act > 0) {
if (flag_act == 1) { // relu
c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0;
c[i] = c[i] < six ? c[i] : six;
} else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
}
} else {
c[i] = tmp;
}
......@@ -230,7 +239,9 @@ static void conv_basic(const Dtype1* din,
int pad_w,
int pad_h,
bool flag_bias,
bool flag_relu) {
int act_type,
float six = 6.f,
float scale = 1.f) {
Dtype2 beta = 0;
auto src_data = din;
auto dst_data_ref = dout;
......@@ -280,10 +291,27 @@ static void conv_basic(const Dtype1* din,
}
}
}
if (flag_relu) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
if (act_type > 0) {
// 1-relu 2-relu6 4-leakyrelu
if (act_type == 1) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
} else if (act_type == 2) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
? dst_data_ref[out_idx]
: (Dtype2)six;
} else if (act_type == 4) {
dst_data_ref[out_idx] =
dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)(dst_data_ref[out_idx] * scale);
} else {
printf("this act type: %d does not support \n", act_type);
}
}
}
}
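
As a quick standalone check of the activation arithmetic used by the reference implementations above (plain C++; values chosen arbitrarily, and exact in binary floating point so the equality asserts hold):

#include <algorithm>
#include <cassert>

int main() {
  const float six = 6.f, alpha = 0.1f;
  float relu = std::max(-2.f, 0.f);                  // 1-relu   -> 0
  float relu6 = std::min(std::max(7.5f, 0.f), six);  // 2-relu6  -> 6
  float leaky = -2.f < 0.f ? -2.f * alpha : -2.f;    // 4-leaky  -> -0.2
  assert(relu == 0.f && relu6 == 6.f && leaky == -0.2f);
  return 0;
}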
......