Commit 42bbd157 authored by yiicy, committed by GitHub

[ARM] 5x5dw and sgemv support fuse activation, test=develop (#2797)

* refactor 5x5s1 dw conv armv8, test=develop

* [ARM] refactor depthwise conv 5x5s1, and support relu6, leakey relu, test=develop

* [ARM] sgemv support fuse relu6 and leakey relu,test=develop

* [ARM] reduce some conv ut case, test=develop

* [ARM] fix 5x5dw conv pick kernel bug, test=develop

* fix code style, test=develop

* [ARM] fix sgemv fuse relu6 bug, test=develop

* [ARM] fix fp32 5x5s1 dw bug, test=develop

* [ARM] fix fp32 5x5 dw conv pick kernel bug, test=develop
Parent 7a0e3fd7
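The recurring change in the assembly hunks that follow is the comparison used by the fused leaky-ReLU paths: the unsigned integer compare cmhs is replaced by the floating-point compare fcmge. Compared against a zero register, cmhs treats each lane's bit pattern as an unsigned integer, so every lane (negative values included) tests as "higher or same" and the alpha-scaled branch is never selected. A minimal NEON-intrinsics sketch of the pattern these LEAKY_RELU macros implement, written for illustration only (not code from the patch):

#include <arm_neon.h>

// y = x >= 0 ? x : alpha * x, four fp32 lanes at a time.
static inline float32x4_t leaky_relu_f32x4(float32x4_t x, float32x4_t valpha) {
  float32x4_t vzero = vdupq_n_f32(0.f);
  // vcgeq_f32 maps to fcmge: per-lane float compare, all-ones mask where x >= 0.
  // The old cmhs compared raw bit patterns as unsigned integers, which is >= 0
  // for every lane, so the leaky slope was never applied.
  uint32x4_t ge_zero = vcgeq_f32(x, vzero);
  float32x4_t scaled = vmulq_f32(x, valpha);  // alpha * x, used for negative lanes
  // Keep x where the mask is set, take alpha * x elsewhere (the "bif" step in the asm).
  return vbslq_f32(ge_zero, x, scaled);
}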
@@ -614,8 +614,8 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"blt 3f \n" "blt 3f \n"
#define LEFT_RESULT_S1_LEAKY_RELU \ #define LEFT_RESULT_S1_LEAKY_RELU \
"cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
@@ -639,7 +639,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\ \
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \
"cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
\ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
@@ -657,7 +657,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\ \
"cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \
@@ -802,7 +802,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#define MID_RESULT_S1_LEAKY_RELU \ #define MID_RESULT_S1_LEAKY_RELU \
"movi v21.4s, #0 \n" \ "movi v21.4s, #0 \n" \
"cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v12.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
\ \
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -824,7 +824,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v13.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\ \
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -846,7 +846,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v14.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \
\ \
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -861,7 +861,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \
\ \
"cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v15.4s, v21.4s \n" /* vcgeq_f32 */ \
"fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \
\ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
@@ -980,7 +980,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#define RIGHT_RESULT_S1_LEAKY_RELU \ #define RIGHT_RESULT_S1_LEAKY_RELU \
"movi v1.4s, #0 \n" \ "movi v1.4s, #0 \n" \
"cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \ "fcmge v20.4s, v12.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \
\ \
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
@@ -999,7 +999,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\ \
"cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \ "fcmge v20.4s, v13.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \
\ \
@@ -1017,7 +1017,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\ \
"cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \ "fcmge v20.4s, v14.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \ "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
\ \
@@ -1028,7 +1028,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
\ \
"bif v14.16b, v24.16b, v18.16b \n" \ "bif v14.16b, v24.16b, v18.16b \n" \
\ \
"cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \ "fcmge v20.4s, v15.4s, v1.4s \n" /* vcgeq_f32 */ \
"fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \ "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \
\ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \
@@ -1132,8 +1132,8 @@ void conv_depthwise_3x3s1_fp32(const float *din,
"prfm pldl1keep, [%[out1]]\n" \ "prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \ "prfm pldl1keep, [%[out2]]\n" \
\ \
"cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \
"fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \
\ \
......
@@ -179,11 +179,11 @@ namespace math {
#define LEAKY_RELU \ #define LEAKY_RELU \
"movi v0.4s, #0\n" /* for relu */ \ "movi v0.4s, #0\n" /* for relu */ \
"ldr x0, [%[outl], #88]\n" \ "ldr x0, [%[outl], #88]\n" \
"cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v1.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */ \
"cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ \
"ld1 {v9.4s}, [x0] \n" \ "ld1 {v9.4s}, [x0] \n" \
"cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v3.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ \
"cmhs v4.4s, v18.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v4.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ \
"ldr x0, [%[outl]] \n" \ "ldr x0, [%[outl]] \n" \
"fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \ "fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \
"fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \ "fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \
@@ -193,10 +193,10 @@ namespace math {
"bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \ "bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \
"bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \ "bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \
"bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \ "bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \
"cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \
"cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v3.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \
"cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v4.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \ "fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \
"fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \ "fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \
"fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \ "fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \
......
@@ -50,7 +50,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias; bool flag_bias = param.bias;
int pad_h = paddings[0]; int pad_h = paddings[0];
int pad_w = paddings[1]; int pad_w = paddings[2];
const int threads = ctx->threads(); const int threads = ctx->threads();
int llc_size = ctx->llc_size() / 4; int llc_size = ctx->llc_size() / 4;
@@ -477,7 +477,7 @@ void conv_3x3s2_direct_int8(const int8_t* din,
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias; bool flag_bias = param.bias;
int pad_h = paddings[0]; int pad_h = paddings[0];
int pad_w = paddings[1]; int pad_w = paddings[2];
const int threads = ctx->threads(); const int threads = ctx->threads();
//! set 1/4 l2 cache //! set 1/4 l2 cache
int llc_size = ctx->llc_size() / 4; int llc_size = ctx->llc_size() / 4;
......
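Both conv_3x3s2_direct_int8 hunks above correct how the horizontal padding is read: pad_w now comes from paddings[2] instead of paddings[1]. The paddings vector in this code base is laid out as {top, bottom, left, right} (the fp32 paths further down read pad_h = paddings[0], pad_w = paddings[2] the same way), so index 1 is the bottom pad rather than the left one, and the old code only happened to work when the vertical and horizontal pads coincided. A small sketch of that convention, with illustrative names:

#include <vector>

struct Pad2D { int top, bottom, left, right; };

// Assumed {top, bottom, left, right} layout of the 4-element paddings vector.
inline Pad2D unpack_paddings(const std::vector<int>& paddings) {
  return {paddings[0], paddings[1], paddings[2], paddings[3]};
}

// Standard output-size formula; mixing top/bottom with left/right breaks it as
// soon as the padding is asymmetric.
inline int conv_out_size(int in, int pad_lo, int pad_hi, int k, int stride,
                         int dilation) {
  int k_ext = dilation * (k - 1) + 1;
  return (in + pad_lo + pad_hi - k_ext) / stride + 1;
}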
@@ -453,7 +453,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
#define LEFT_RESULT_S2_LEAKY_RELU \ #define LEFT_RESULT_S2_LEAKY_RELU \
"ld1 {v22.4s}, [%[scale_ptr]] \n" \ "ld1 {v22.4s}, [%[scale_ptr]] \n" \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
\ \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
@@ -475,7 +475,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"ext v10.16b, v0.16b, v15.16b, #4 \n" \ "ext v10.16b, v0.16b, v15.16b, #4 \n" \
\ \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \ "fmul v12.4s, v16.4s, v22.4s \n" \
\ \
"ld1 {v20.4s}, [%[inptr3]] \n" \ "ld1 {v20.4s}, [%[inptr3]] \n" \
@@ -543,7 +543,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"bne 2b \n" "bne 2b \n"
#define MID_RESULT_S2_LEAKY_RELU \ #define MID_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \ "fmul v12.4s, v16.4s, v22.4s \n" \
\ \
"fadd v17.4s, v17.4s, v13.4s \n" \ "fadd v17.4s, v17.4s, v13.4s \n" \
@@ -554,7 +554,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\ \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \ "ext v10.16b, v0.16b, v15.16b, #4 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \ "fmul v12.4s, v17.4s, v22.4s \n" \
\ \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \
@@ -607,7 +607,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"4: \n" "4: \n"
#define RIGHT_RESULT_S2_LEAKY_RELU \ #define RIGHT_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \ "fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \ "fadd v17.4s, v17.4s, v13.4s \n" \
\ \
@@ -617,7 +617,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\ \
"bif v16.16b, v0.16b, %[wmask].16b \n" \ "bif v16.16b, v0.16b, %[wmask].16b \n" \
\ \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \ "fmul v12.4s, v17.4s, v22.4s \n" \
\ \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \
......
@@ -104,13 +104,13 @@ namespace math {
"fmin v22.4s, v22.4s, %[vsix].4s\n" "fmin v22.4s, v22.4s, %[vsix].4s\n"
#define LEAKY_RELU /* LeakyRelu */ \ #define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \ "movi v0.4s, #0\n" /* for relu */ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
......
@@ -709,7 +709,6 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
"q15"); "q15");
#endif #endif
// clang-format on // clang-format on
int32_t* ptr_tmp = ptr_out0 - w_loop * 32;
block_inr0 = block_inr1; block_inr0 = block_inr1;
block_inr1 = block_inr2; block_inr1 = block_inr2;
block_inr2 = block_inr3; block_inr2 = block_inr3;
......
@@ -200,13 +200,13 @@ namespace math {
"fmin v22.4s, v22.4s, %[vsix].4s\n" "fmin v22.4s, v22.4s, %[vsix].4s\n"
#define LEAKY_RELU /* LeakyRelu */ \ #define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \ "movi v0.4s, #0\n" /* for relu */ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
......
@@ -614,10 +614,10 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
"fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC1_TRANS_FP32_LEAKY_RELU \ #define NCHWC1_TRANS_FP32_LEAKY_RELU \
"cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v4.4s, v0.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v5.4s, v1.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \ "fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \
"fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \ "fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \
"fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \ "fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \
@@ -935,8 +935,8 @@ inline bool write_to_output_c1_fp32(const float* din,
"fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */
#define NCHWC2_TRANS_FP32_LEAKY_RELU \ #define NCHWC2_TRANS_FP32_LEAKY_RELU \
"cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \ "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \ "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \
"bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \ "bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \
@@ -1276,10 +1276,10 @@ inline bool write_to_output_c2_fp32(const float* din,
"fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */ "fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
#define NCHWC4_TRANS_FP32_LEAKY_RELU \ #define NCHWC4_TRANS_FP32_LEAKY_RELU \
"cmhs v8.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v8.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v9.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v9.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v10.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v10.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \
"cmhs v11.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \ "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \ "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
"fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \ "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
@@ -1754,15 +1754,15 @@ inline bool write_to_output_c4_fp32(const float* din,
"fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/ "fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/
#define NCHWC8_TRANS_FP32_LEAKY_RELU \ #define NCHWC8_TRANS_FP32_LEAKY_RELU \
"cmhs v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
\ \
"cmhs v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \
"cmhs v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \ "fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \
\ \
"fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \ "fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \
"fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \ "fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \
@@ -2169,13 +2169,13 @@ inline void act_switch_c8_fp32(const float* din_ptr,
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */ "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define DO_LEAKY_RELU \ #define DO_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
@@ -2217,7 +2217,7 @@ inline void act_switch_c8_fp32(const float* din_ptr,
"vbif q3, q8, q7 @ choose \n" \ "vbif q3, q8, q7 @ choose \n" \
"vbif q4, q10, q9 @ choose \n" \ "vbif q4, q10, q9 @ choose \n" \
"vbif q5, q12, q11 @ choose \n" \ "vbif q5, q12, q11 @ choose \n" \
"vbif q6, q13, q13 @ choose \n" "vbif q6, q14, q13 @ choose \n"
#define DO_STORE \ #define DO_STORE \
"subs %[cnt], #1 \n" \ "subs %[cnt], #1 \n" \
"vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \ "vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
......
@@ -123,20 +123,21 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
int padh, int padh,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_5x5s1_fp32(const float* din, void conv_depthwise_5x5s1_fp32(float* dout,
float* dout, const float* din,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights, const float* weights,
const float* bias, const float* bias,
int pad,
bool flag_bias, bool flag_bias,
bool flag_relu, bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
const operators::ConvParam& param,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_5x5s2_fp32(const float* din, void conv_depthwise_5x5s2_fp32(const float* din,
......
@@ -188,7 +188,6 @@ void conv1x1s1_gemm(const float* i_data,
if (n > 1) { if (n > 1) {
weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
} }
//! use gemv when the output channel size = 1 //! use gemv when the output channel size = 1
for (int b = 0; b < num; ++b) { for (int b = 0; b < num; ++b) {
// dC // dC
@@ -210,8 +209,11 @@ void conv1x1s1_gemm(const float* i_data,
k, k,
flag_bias, flag_bias,
bias_group, bias_group,
flag_relu, act_param.has_active,
ctx); act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else { } else {
sgemm_prepack(false, sgemm_prepack(false,
m, m,
@@ -410,8 +412,11 @@ void conv_im2col_gemm(const float* i_data,
k, k,
flag_bias, flag_bias,
bias_group, bias_group,
flag_relu, act_param.has_active,
ctx); act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else { } else {
int ldb = n; int ldb = n;
sgemm_prepack(false, sgemm_prepack(false,
@@ -677,7 +682,8 @@ void conv_depthwise_5x5_fp32(const void* din,
const float* scale) { const float* scale) {
auto paddings = *param.paddings; auto paddings = *param.paddings;
auto act_param = param.activation_param; auto act_param = param.activation_param;
int pad = paddings[0]; int pad_h = paddings[0];
int pad_w = paddings[2];
int stride = param.strides[1]; int stride = param.strides[1];
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
@@ -698,20 +704,21 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param, act_param,
ctx); ctx);
} else if (stride == 1) { } else if (stride == 1) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<const float*>(din), conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<float*>(dout), reinterpret_cast<const float*>(din),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
bias, bias,
pad,
flag_bias, flag_bias,
flag_relu, flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
param,
ctx); ctx);
} else { } else {
LOG(FATAL) << "unsupport this type 5x5 dw conv"; LOG(FATAL) << "unsupport this type 5x5 dw conv";
......
@@ -137,13 +137,13 @@ void fill_bias_relu<int>(int* tensor,
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */ "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define FILL_LEAKY_RELU \ #define FILL_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
......
This diff is collapsed.
@@ -17,23 +17,26 @@
#include <cmath> #include <cmath>
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/device_info.h" #include "lite/core/device_info.h"
#include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// TODO(xxx): fixme now only support transA = false bool sgemv(const float *A,
bool sgemv(const float* A, const float *x,
const float* x, float *y,
float* y,
bool transA, bool transA,
int M, int M,
int N, int N,
bool is_bias, bool is_bias,
const float* bias, const float *bias,
bool is_relu, bool flag_act,
const ARMContext* ctx); lite_api::ActivationType act,
const ARMContext *ctx,
float six = 6.f,
float alpha = 1.f);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
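The sgemv interface above replaces the single is_relu flag with a flag_act / ActivationType pair plus optional relu6 clip (six) and leaky-ReLU slope (alpha). A hedged call-site sketch against that declaration; the include path, wrapper function and literal values are illustration only, and only the parameter list itself comes from the header diff:

#include "lite/backends/arm/math/sgemv.h"  // assumed include path

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Illustrative wrapper: run y = A * x with a fused relu6.
void gemv_relu6_example(const float* A, const float* x, float* y,
                        int M, int N, const float* bias,
                        const ARMContext* ctx) {
  // Old interface ended in a plain bool is_relu; the new one states whether an
  // activation is fused, which one, and its parameters.
  sgemv(A, x, y, /*transA=*/false, M, N,
        /*is_bias=*/bias != nullptr, bias,
        /*flag_act=*/true, lite_api::ActivationType::kRelu6, ctx,
        /*six=*/6.f, /*alpha=*/1.f);
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle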
@@ -42,8 +42,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int stride = param.strides[0]; int stride = param.strides[0];
int threads = ctx.threads(); int threads = ctx.threads();
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
int chin = param.x->dims()[1]; int chin = param.x->dims()[1];
int hin = param.x->dims()[2]; int hin = param.x->dims()[2];
int win = param.x->dims()[3]; int win = param.x->dims()[3];
@@ -51,28 +49,28 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
int hout = param.output->dims()[2]; int hout = param.output->dims()[2];
int wout = param.output->dims()[3]; int wout = param.output->dims()[3];
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]);
bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh); bool ks_equal = (param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2));
bool flag_dw_5x5 = (paddings[0] == paddings[2]) && bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2);
((kw == 5 && stride == 1) || (kw == 5 && stride == 2)); bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl /// select conv impl
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { if (param.groups == ic && ic == oc && ks_equal && no_dilation && flag_dw) {
/// dw conv impl
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking dw conv"; // VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation && pads_all_equal) { no_dilation && pads_all_equal) {
/// winograd conv impl // TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv"; // VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 && } else if (param.groups == 1 && kw == 3 && stride == 2 &&
chin * chout < 4 * hin * win && kps_equal && no_dilation) { chin * chout < 4 * hin * win && ks_equal && no_dilation) {
/// direct conv impl
impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking direct conv"; // VLOG(3) << "invoking direct conv";
} else { } else {
@@ -109,7 +107,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && sw == 1); bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
@@ -154,7 +152,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && sw == 1); bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
......
@@ -52,7 +52,10 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
impl_ = lite::arm::math::conv_depthwise_3x3_fp32; impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
} else if (kw == 5) { } else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32"; // VLOG(5) << "invoke 5x5 dw conv fp32";
if (param.strides[0] == 2) { // conv5x5s2_dw auto strides = param.strides;
if ((strides[0] == 1 && strides[1] == 1) ||
(strides[0] == 2 && strides[1] == 2)) {
// trans weights
constexpr int cblock = 4; constexpr int cblock = 4;
auto oc = w_dims[0]; auto oc = w_dims[0];
auto kh = w_dims[2]; auto kh = w_dims[2];
@@ -63,10 +66,11 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
lite::arm::math::conv_trans_weights_numc( lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw); w_data_in, w_data, oc, 1, cblock, kh * kw);
flag_trans_weights_ = true; flag_trans_weights_ = true;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else { } else {
flag_trans_weights_ = false; LOG(FATAL)
<< "5x5 depthwise conv only support stride == 1 or stride == 2";
} }
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else { } else {
LOG(FATAL) << "this type dw conv not impl"; LOG(FATAL) << "this type dw conv not impl";
} }
......
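The 5x5 depthwise fp32 path above now accepts stride 1 as well as stride 2, repacks the weights with conv_trans_weights_numc using a channel block of 4 before dispatching, and fails loudly for any other stride. The exact layout that routine produces is not shown in this diff; the sketch below only illustrates the usual idea behind such a cblock = 4 repack, with illustrative names: group output channels in blocks of four so a single 128-bit load fetches the same kernel tap for four channels.

#include <algorithm>
#include <vector>

// Illustrative channel-blocked repack (not the library routine itself).
// in : oc x khw depthwise weights, channel-major; khw = kh * kw
// out: caller-allocated, ceil(oc/4)*4 * khw floats, padded channels are zero
inline void repack_dw_weights_c4(const float* in, float* out, int oc, int khw) {
  int oc_pad = (oc + 3) / 4 * 4;
  std::vector<float> padded(static_cast<size_t>(oc_pad) * khw, 0.f);
  std::copy(in, in + static_cast<size_t>(oc) * khw, padded.begin());
  for (int c = 0; c < oc_pad; ++c) {
    for (int k = 0; k < khw; ++k) {
      // block-major: [channel block][kernel tap][channel within block]
      out[(c / 4) * khw * 4 + k * 4 + (c % 4)] = padded[c * khw + k];
    }
  }
}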
@@ -93,9 +93,11 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
bool flag_relu = false; bool flag_act = false;
lite_api::ActivationType act;
if (param.activation_type == "relu") { if (param.activation_type == "relu") {
flag_relu = true; act = lite_api::ActivationType::kRelu;
flag_act = true;
} }
if (flag_gemm_) { if (flag_gemm_) {
operators::ActivationParam act_param; operators::ActivationParam act_param;
@@ -119,7 +121,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
&ctx); &ctx);
if (param.bias) { if (param.bias) {
CHECK_EQ(param.bias->numel(), n_); CHECK_EQ(param.bias->numel(), n_);
lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu); lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_act);
} }
} else { } else {
for (int i = 0; i < m_; ++i) { for (int i = 0; i < m_; ++i) {
@@ -133,7 +135,8 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
k_, k_,
param.bias != nullptr, param.bias != nullptr,
b_data, b_data,
flag_relu, flag_act,
act,
&ctx); &ctx);
} }
} }
......
@@ -233,8 +233,17 @@ void MatMulCompute::Run() {
int ldb = n_; int ldb = n_;
int ldc = n_; int ldc = n_;
if (n_ == 1) { if (n_ == 1) {
lite::arm::math::sgemv( lite::arm::math::sgemv(x_data,
x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx); y_data,
o_data,
false,
m_,
k_,
false,
nullptr,
false,
lite_api::ActivationType::kIndentity,
&ctx);
if (fabsf(alpha - 1.f) > 1e-8f) { if (fabsf(alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) { for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= alpha; o_data[i] *= alpha;
......
@@ -50,8 +50,17 @@ void MulCompute::Run() {
k_ = x_w; k_ = x_w;
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
if (n_ == 1) { if (n_ == 1) {
lite::arm::math::sgemv( lite::arm::math::sgemv(x_data,
x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx); y_data,
o_data,
false,
m_,
k_,
false,
nullptr,
false,
lite_api::ActivationType::kIndentity,
&ctx);
} else { } else {
constexpr bool is_tranposed_y = false; constexpr bool is_tranposed_y = false;
......
@@ -131,7 +131,7 @@ class FcOPTest : public arena::TestCase {
1.f, 1.f,
0.f, 0.f,
true, true,
flag_bias, static_cast<int>(flag_bias),
false); false);
} else { } else {
basic_gemm(false, basic_gemm(false,
......
@@ -46,14 +46,19 @@ DEFINE_int32(out_channel, 32, "output channel");
DEFINE_int32(group, 1, "group"); DEFINE_int32(group, 1, "group");
DEFINE_int32(kernel_h, 3, "kernel height"); DEFINE_int32(kernel_h, 3, "kernel height");
DEFINE_int32(kernel_w, 3, "kernel width"); DEFINE_int32(kernel_w, 3, "kernel width");
DEFINE_int32(pad_h, 1, "pad height"); DEFINE_int32(pad_h0, 1, "pad top");
DEFINE_int32(pad_w, 1, "pad width"); DEFINE_int32(pad_h1, 1, "pad bottom");
DEFINE_int32(pad_w0, 1, "pad left");
DEFINE_int32(pad_w1, 1, "pad right");
DEFINE_int32(stride_h, 1, "stride height"); DEFINE_int32(stride_h, 1, "stride height");
DEFINE_int32(stride_w, 1, "stride width"); DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height"); DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width"); DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_relu, true, "do relu"); DEFINE_int32(flag_act,
0,
"do activation"); // 0-no act, 1-relu, 2-relu6, 4-leakyrelu
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_bool(flag_bias, true, "with bias"); DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim; typedef paddle::lite::DDim DDim;
@@ -98,9 +103,10 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
const std::vector<int>& pads, const std::vector<int>& pads,
const std::vector<int>& dilas, const std::vector<int>& dilas,
bool flag_bias, bool flag_bias,
bool flag_relu, int flag_act,
const std::vector<int>& thread_num, const std::vector<int>& thread_num,
const std::vector<int>& power_mode) { const std::vector<int>& power_mode,
const float leakey_relu_scale) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init(); paddle::lite::DeviceInfo::Init();
#endif #endif
@@ -118,13 +124,20 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
param.strides = strides; param.strides = strides;
param.paddings = std::make_shared<std::vector<int>>(pads); param.paddings = std::make_shared<std::vector<int>>(pads);
param.dilations = std::make_shared<std::vector<int>>(dilas); param.dilations = std::make_shared<std::vector<int>>(dilas);
param.fuse_relu = flag_relu;
param.groups = group; param.groups = group;
if (flag_relu) { const float six = 6.f;
if (flag_act > 0) {
ActivationParam act_param; ActivationParam act_param;
act_param.has_active = true; act_param.has_active = true;
act_param.active_type = act_param.active_type = (paddle::lite_api::ActivationType)
(paddle::lite_api::ActivationType)1; // 2-relu6 4-leakyrelu flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.activation_param = act_param; param.activation_param = act_param;
} }
@@ -205,7 +218,9 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
pads[2], pads[2],
pads[0], pads[0],
flag_bias, flag_bias,
flag_relu); flag_act,
six,
leakey_relu_scale);
} }
/// warm up /// warm up
for (int i = 0; i < FLAGS_warmup; ++i) { for (int i = 0; i < FLAGS_warmup; ++i) {
@@ -254,22 +269,20 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
<< ", dila_: " << dilas[0] << ", " << dilas[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group << ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", act: " << flag_act << ", threads: " << th
<< ", threads: " << th << ", power_mode: " << cls << ", power_mode: " << cls << " failed!!\n";
<< " failed!!\n";
} }
} }
} }
LOG(INFO) << "test fp32 conv: input: " << dim_in LOG(INFO) << "test fp32 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim << ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", stride: " << strides[0] << ", " << strides[1] << ", " << pads[3] << ", stride: " << strides[0] << ", "
<< ", dila_: " << dilas[0] << ", " << dilas[1] << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group << ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", act: " << flag_act << ", threads: " << th
<< ", threads: " << th << ", power_mode: " << cls << ", power_mode: " << cls << " successed!!\n";
<< " successed!!\n";
} }
} }
} }
@@ -287,12 +300,14 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
const std::vector<int>& pads, const std::vector<int>& pads,
const std::vector<int>& dilas, const std::vector<int>& dilas,
bool flag_bias, bool flag_bias,
bool flag_relu, int flag_act,
const std::vector<int>& thread_num, const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {} const std::vector<int>& power_mode,
const float leakey_relu_scale) {}
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
#if 1 /// 3x3dw // TODO(chenjiaoAngel): fix me, diff: 3x3 depthwise conv
#if 0 /// 3x3dw
TEST(TestConv3x3DW, test_conv3x3_depthwise) { TEST(TestConv3x3DW, test_conv3x3_depthwise) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
...@@ -301,7 +316,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) { ...@@ -301,7 +316,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
for (auto& pad_top : {0, 1, 2}) { for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) { for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) { for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3}); DDim weights_dim({c, 1, 3, 3});
@@ -310,6 +325,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
dims.push_back(DDim({batch, c, h, h})); dims.push_back(DDim({batch, c, h, h}));
} }
} }
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims, test_conv_fp32(dims,
weights_dim, weights_dim,
c, c,
@@ -317,9 +333,10 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 2, 4}, {1, 2, 4},
{FLAGS_power_mode}); {FLAGS_power_mode},
leakey_relu_scale);
} }
} }
} }
@@ -335,28 +352,41 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) {
#if 1 /// 5x5dw #if 1 /// 5x5dw
TEST(TestConv5x5DW, test_conv5x5_depthwise) { TEST(TestConv5x5DW, test_conv5x5_depthwise) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
#ifdef __aarch64__
// TODO(chenjiaoAngel): fix me, diff: arm64 5x5s2 depthwise conv
for (auto& stride : {1}) {
#else
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) { #endif
for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) { for (auto& c : {1, 15, 32}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5}); DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { for (auto& h : {1, 3, 15, 56}) {
dims.push_back(DDim({batch, c, h, h})); dims.push_back(DDim({batch, c, h, h}));
} }
} }
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims, test_conv_fp32(dims,
weights_dim, weights_dim,
c, c,
{stride, stride}, {stride, stride},
{pad, pad, pad, pad}, {pad_left, pad_right, pad_top, pad_bottom},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
leakey_relu_scale);
}
}
}
} }
} }
} }
@@ -373,7 +403,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
for (auto& cout : {1, 5, 16, 37}) { for (auto& cout : {1, 5, 16, 37}) {
for (auto& g : {1, 2}) { for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims; std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
continue; continue;
@@ -384,6 +414,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims, test_conv_fp32(dims,
weights_dim, weights_dim,
g, g,
@@ -391,9 +422,10 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
{0, 0, 0, 0}, {0, 0, 0, 0},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 2, 4}, {1, 2, 4},
{FLAGS_power_mode}); {FLAGS_power_mode},
leakey_relu_scale);
} }
} }
} }
@@ -403,24 +435,29 @@ TEST(TestConv1x1s1, test_conv1x1s1) {
} }
#endif /// conv1x1s1 #endif /// conv1x1s1
#if 1 /// conv3x3s1 // TODO(MyPandaShaoxiang): fix me, diff: 3x3s1 winograd
#if 0 /// conv3x3s1
TEST(TestConv3x3s1, test_conv_3x3s1) { TEST(TestConv3x3s1, test_conv_3x3s1) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32, 48}) { for (auto& cin : {1, 3, 8, 8}) {
for (auto& cout : {1, 5, 8, 32, 48}) { for (auto& cout : {1, 5, 32, 48}) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {1, 2}) { for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {1, 2}) { for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {1, 2}) { for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 56, 32}) { for (auto& h : {1, 3, 17, 33}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
if (cin == 1 && cout ==1) {
continue;
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims, test_conv_fp32(dims,
weights_dim, weights_dim,
1, 1,
@@ -428,9 +465,10 @@ TEST(TestConv3x3s1, test_conv_3x3s1) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
leakey_relu_scale);
} }
} }
} }
@@ -446,21 +484,25 @@ TEST(TestConv3x3s1, test_conv_3x3s1) {
#if 1 /// conv3x3s2 #if 1 /// conv3x3s2
TEST(TestConv3x3s2, test_conv_3x3s2) { TEST(TestConv3x3s2, test_conv_3x3s2) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) { for (auto& cin : {1, 3, 8}) {
for (auto& cout : {1, 5, 8, 32}) { for (auto& cout : {1, 3, 9, 32}) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {0, 1, 2}) {
for (auto& pad_right : {1, 2}) { for (auto& pad_right : {0, 1, 2}) {
for (auto& pad_top : {1, 2}) { for (auto& pad_top : {0, 1, 2}) {
for (auto& pad_bottom : {1, 2}) { for (auto& pad_bottom : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { for (auto& h : {3, 7, 15, 56, 32}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
if (cin == 1 && cout == 1) {
continue;
}
const float leakey_relu_scale = 8.88;
test_conv_fp32(dims, test_conv_fp32(dims,
weights_dim, weights_dim,
1, 1,
@@ -468,9 +510,10 @@ TEST(TestConv3x3s2, test_conv_3x3s2) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 2, 4}, {1, 2, 4},
{FLAGS_power_mode}); {FLAGS_power_mode},
leakey_relu_scale);
} }
} }
} }
@@ -486,29 +529,40 @@ TEST(TestConv3x3s2, test_conv_3x3s2) {
#if 1 /// random param conv #if 1 /// random param conv
TEST(TestConvRand, test_conv_rand) { TEST(TestConvRand, test_conv_rand) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 16}) { for (auto& cin : {1, 3, 8}) {
for (auto& cout : {1, 5, 8, 16}) { for (auto& cout : {1, 5, 16}) {
for (auto& g : {1, 2}) { for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) { for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
for (auto& pad_left : {0, 1, 2}) { for (auto& pad_left : {0, 2}) {
for (auto& pad_right : {0, 1, 2}) { for (auto& pad_right : {0, 2}) {
for (auto& pad_top : {0, 1, 2}) { for (auto& pad_top : {0, 2}) {
for (auto& pad_bottom : {0, 1, 2}) { for (auto& pad_bottom : {0, 2}) {
for (auto& dila : {1, 2}) { for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
continue; continue;
} }
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw}); DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) { for (auto& h : {1, 3, 19, 32}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
// skip 3x3 depthwise conv
if (g == cin && cin == cout && kw == 3 &&
kh == 3) {
break;
}
// skip 3x3s1 direct conv
if (g == 1 && (cin != 1 || cout != 1) &&
kw == 3 && kh == 3 && stride == 1) {
break;
}
const float leakey_relu_scale = 8.88;
test_conv_fp32( test_conv_fp32(
dims, dims,
weights_dim, weights_dim,
@@ -517,9 +571,10 @@ TEST(TestConvRand, test_conv_rand) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{dila, dila}, {dila, dila},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
leakey_relu_scale);
} }
} }
} }
@@ -551,11 +606,12 @@ TEST(TestConvCustom, test_conv_fp32_custom_size) {
FLAGS_kernel_w}), FLAGS_kernel_w}),
FLAGS_group, FLAGS_group,
{FLAGS_stride_h, FLAGS_stride_w}, {FLAGS_stride_h, FLAGS_stride_w},
{FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w}, {FLAGS_pad_h0, FLAGS_pad_h1, FLAGS_pad_w0, FLAGS_pad_w1},
{FLAGS_dila_h, FLAGS_dila_w}, {FLAGS_dila_h, FLAGS_dila_w},
FLAGS_flag_bias, FLAGS_flag_bias,
FLAGS_flag_relu, FLAGS_flag_act,
{FLAGS_threads}, {FLAGS_threads},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_leakey_relu_alpha);
} }
#endif // custom #endif // custom
@@ -291,7 +291,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads[2], pads[2],
pads[0], pads[0],
flag_bias, flag_bias,
flag_relu); static_cast<int>(flag_relu));
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32, paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32,
dout_basic_int8, dout_basic_int8,
scale_out.data(), scale_out.data(),
@@ -362,6 +362,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< pads[2] << ", " << pads[3] << pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1] << ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls << ", threads: " << th << ", power_mode: " << cls
...@@ -467,7 +468,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -467,7 +468,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3}); DDim weights_dim({c, 1, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 75, 32, 28}) { for (auto& h : {1, 3, 15, 33}) {
dims.push_back(DDim({batch, c, h, h})); dims.push_back(DDim({batch, c, h, h}));
} }
} }
...@@ -479,7 +480,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -479,7 +480,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_relu,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
...@@ -494,14 +495,14 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -494,14 +495,14 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& stride : {1}) { for (auto& stride : {1}) {
for (auto& pad : {0, 1, 2}) { for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_relu : {false, true}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) { for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5}); DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { for (auto& h : {1, 3, 15, 33}) {
dims.push_back(DDim({batch, c, h, h})); dims.push_back(DDim({batch, c, h, h}));
} }
} }
...@@ -513,7 +514,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -513,7 +514,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_relu,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
...@@ -527,8 +528,8 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -527,8 +528,8 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
#if 1 /// conv1x1s1 #if 1 /// conv1x1s1
TEST(TestConv1x1s1Int8, test_conv1x1s1) { TEST(TestConv1x1s1Int8, test_conv1x1s1) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 11, 32}) { for (auto& cin : {1, 3, 8, 32}) {
for (auto& cout : {1, 5, 16, 37}) { for (auto& cout : {1, 5, 17}) {
for (auto& g : {1, 2}) { for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_relu : {false, true}) {
...@@ -538,7 +539,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -538,7 +539,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
} }
DDim weights_dim({cout, cin / g, 1, 1}); DDim weights_dim({cout, cin / g, 1, 1});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 32, 56, 1}) { for (auto& h : {1, 9, 16, 33}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
...@@ -550,7 +551,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -550,7 +551,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_relu,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
...@@ -564,8 +565,8 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -564,8 +565,8 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
#if 1 /// conv3x3s1 #if 1 /// conv3x3s1
TEST(TestConv3x3s1Int8, test_conv_3x3s1) { TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32, 48}) { for (auto& cin : {1, 3, 8, 33}) {
for (auto& cout : {1, 5, 8, 32, 48}) { for (auto& cout : {1, 5, 33}) {
for (auto& pad_top : {1, 2}) { for (auto& pad_top : {1, 2}) {
for (auto& pad_bottom : {1, 2}) { for (auto& pad_bottom : {1, 2}) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {1, 2}) {
...@@ -575,7 +576,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -575,7 +576,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 56, 32}) { for (auto& h : {1, 7, 17, 33}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
...@@ -587,7 +588,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -587,7 +588,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_relu,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
...@@ -604,8 +605,8 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -604,8 +605,8 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
#if 1 /// conv3x3s2 #if 1 /// conv3x3s2
TEST(TestConv3x3s2Int8, test_conv_3x3s2) { TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) { for (auto& cin : {1, 3, 31}) {
for (auto& cout : {1, 5, 8, 32}) { for (auto& cout : {1, 5, 33}) {
for (auto& pad_top : {1, 2}) { for (auto& pad_top : {1, 2}) {
for (auto& pad_bottom : {1, 2}) { for (auto& pad_bottom : {1, 2}) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {1, 2}) {
...@@ -615,7 +616,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -615,7 +616,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { for (auto& h : {1, 7, 19, 33}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
...@@ -627,7 +628,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -627,7 +628,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_relu,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
...@@ -644,8 +645,8 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -644,8 +645,8 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
#if 1 /// random param conv #if 1 /// random param conv
TEST(TestConvRandInt8, test_conv_rand) { TEST(TestConvRandInt8, test_conv_rand) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 16}) { for (auto& cin : {1, 17}) {
for (auto& cout : {1, 5, 8, 16}) { for (auto& cout : {1, 8, 17}) {
for (auto& g : {1, 2}) { for (auto& g : {1, 2}) {
for (auto& kw : {1, 2, 3}) { for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) {
...@@ -658,12 +659,12 @@ TEST(TestConvRandInt8, test_conv_rand) { ...@@ -658,12 +659,12 @@ TEST(TestConvRandInt8, test_conv_rand) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_relu : {false, true}) {
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
continue; break;
} }
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin / g, kh, kw}); DDim weights_dim({cout, cin / g, kh, kw});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) { for (auto& h : {1, 3, 5, 19}) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
...@@ -676,7 +677,7 @@ TEST(TestConvRandInt8, test_conv_rand) { ...@@ -676,7 +677,7 @@ TEST(TestConvRandInt8, test_conv_rand) {
{dila, dila}, {dila, dila},
flag_bias, flag_bias,
flag_relu, flag_relu,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
......
...@@ -37,7 +37,7 @@ DEFINE_int32(power_mode, ...@@ -37,7 +37,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests"); DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result"); DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemv: M"); DEFINE_int32(M, 512, "gemv: M");
......
...@@ -37,7 +37,7 @@ DEFINE_int32(power_mode, ...@@ -37,7 +37,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests"); DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result"); DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm_c4: M"); DEFINE_int32(M, 512, "gemm_c4: M");
......
...@@ -38,11 +38,19 @@ DEFINE_int32(K, 512, "sgemv: K"); ...@@ -38,11 +38,19 @@ DEFINE_int32(K, 512, "sgemv: K");
DEFINE_bool(traA, false, "gemv: A transpose"); DEFINE_bool(traA, false, "gemv: A transpose");
DEFINE_bool(flag_relu, false, "do relu"); DEFINE_int32(flag_act, 0, "do act");
DEFINE_bool(flag_bias, false, "with bias"); DEFINE_bool(flag_bias, false, "with bias");
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
bool test_sgemv( DEFINE_double(clipped_coef, 6.0, "clipped relu coef");
bool tra, int m, int k, bool has_bias, bool has_relu, int cls, int ths) { bool test_sgemv(bool tra,
int m,
int k,
bool has_bias,
int flag_act,
int cls,
int ths,
float six = 6.f,
float alpha = 1.f) {
Tensor ta; Tensor ta;
Tensor tb; Tensor tb;
Tensor tc; Tensor tc;
...@@ -68,8 +76,7 @@ bool test_sgemv( ...@@ -68,8 +76,7 @@ bool test_sgemv(
fill_tensor_rand(tbias, -1.f, 1.f); fill_tensor_rand(tbias, -1.f, 1.f);
LOG(INFO) << "sgemv M: " << m << ", K: " << k LOG(INFO) << "sgemv M: " << m << ", K: " << k
<< ", transA: " << (tra ? "true" : "false") << ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false"); << ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
...@@ -78,10 +85,29 @@ bool test_sgemv( ...@@ -78,10 +85,29 @@ bool test_sgemv(
auto dc = tc.mutable_data<float>(); auto dc = tc.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>(); auto dc_basic = tc_basic.mutable_data<float>();
auto dbias = tbias.mutable_data<float>(); auto dbias = tbias.mutable_data<float>();
paddle::lite_api::ActivationType act =
paddle::lite_api::ActivationType::kIndentity;
if (flag_act == 1) {
act = paddle::lite_api::ActivationType::kRelu;
} else if (flag_act == 2) {
act = paddle::lite_api::ActivationType::kRelu6;
} else if (flag_act == 4) {
act = paddle::lite_api::ActivationType::kLeakyRelu;
}
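// A hedged refactor sketch: this flag_act -> ActivationType mapping recurs in
// several of these tests, so it could be factored into one helper. The name
// act_from_flag is hypothetical; anything outside {1, 2, 4} falls back to
// identity, matching the behaviour above (kIndentity is spelled as in the
// lite_api enum).
static paddle::lite_api::ActivationType act_from_flag(int flag_act) {
  switch (flag_act) {
    case 1:
      return paddle::lite_api::ActivationType::kRelu;
    case 2:
      return paddle::lite_api::ActivationType::kRelu6;
    case 4:
      return paddle::lite_api::ActivationType::kLeakyRelu;
    default:
      return paddle::lite_api::ActivationType::kIndentity;
  }
}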
if (FLAGS_check_result) { if (FLAGS_check_result) {
basic_gemv( basic_gemv(m,
m, k, da, db, dbias, dc_basic, 1.f, 0.f, tra, has_bias, has_relu); k,
da,
db,
dbias,
dc_basic,
1.f,
0.f,
tra,
has_bias,
flag_act,
six,
alpha);
} }
paddle::lite::profile::Timer t0; paddle::lite::profile::Timer t0;
//! compute //! compute
...@@ -92,15 +118,37 @@ bool test_sgemv( ...@@ -92,15 +118,37 @@ bool test_sgemv(
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths); ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
/// warmup /// warmup
for (int j = 0; j < FLAGS_warmup; ++j) { for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemv( paddle::lite::arm::math::sgemv(da,
da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx); db,
dc,
tra,
m,
k,
has_bias,
dbias,
flag_act > 0,
act,
&ctx,
six,
alpha);
} }
t0.Reset(); t0.Reset();
for (int i = 0; i < FLAGS_repeats; ++i) { for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start(); t0.Start();
paddle::lite::arm::math::sgemv( paddle::lite::arm::math::sgemv(da,
da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx); db,
dc,
tra,
m,
k,
has_bias,
dbias,
flag_act > 0,
act,
&ctx,
six,
alpha);
t0.Stop(); t0.Stop();
} }
LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls
...@@ -125,7 +173,7 @@ bool test_sgemv( ...@@ -125,7 +173,7 @@ bool test_sgemv(
tensor_diff(tc_basic, tc, tdiff); tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "basic result: "; LOG(INFO) << "basic result: ";
print_tensor(tc_basic); print_tensor(tc_basic);
LOG(INFO) << "saber result: "; LOG(INFO) << "lite result: ";
print_tensor(tc); print_tensor(tc);
LOG(INFO) << "diff result: "; LOG(INFO) << "diff result: ";
print_tensor(tdiff); print_tensor(tdiff);
...@@ -144,22 +192,31 @@ TEST(TestLiteSgemv, Sgemv) { ...@@ -144,22 +192,31 @@ TEST(TestLiteSgemv, Sgemv) {
LOG(INFO) << "run basic sgemv test"; LOG(INFO) << "run basic sgemv test";
for (auto& m : {1, 3, 8, 21, 32, 397}) { for (auto& m : {1, 3, 8, 21, 32, 397}) {
for (auto& k : {1, 3, 8, 17, 59, 234}) { for (auto& k : {1, 3, 8, 17, 59, 234}) {
for (auto& tra : {true, false}) { for (auto& tra : {false, true}) {
for (auto& has_bias : {false, true}) { for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& th : {1, 2, 4}) { for (auto& th : {1, 2, 4}) {
auto flag = test_sgemv( float six = 6.f;
tra, m, k, has_bias, has_relu, FLAGS_cluster, th); float alpha = 8.88f;
auto flag = test_sgemv(tra,
m,
k,
has_bias,
flag_act,
FLAGS_cluster,
th,
six,
alpha);
if (flag) { if (flag) {
LOG(INFO) << "test m = " << m << ", k=" << k LOG(INFO) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false") << ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false") << ", flag act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false") << ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " passed\n"; << ", threads: " << th << " passed\n";
} else { } else {
LOG(FATAL) << "test m = " << m << ", k=" << k LOG(FATAL) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false") << ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false") << ", flag_act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false") << ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " failed\n"; << ", threads: " << th << " failed\n";
} }
...@@ -180,15 +237,17 @@ TEST(TestSgemvCustom, Sgemv_custom) { ...@@ -180,15 +237,17 @@ TEST(TestSgemvCustom, Sgemv_custom) {
FLAGS_M, FLAGS_M,
FLAGS_K, FLAGS_K,
FLAGS_flag_bias, FLAGS_flag_bias,
FLAGS_flag_relu, FLAGS_flag_act,
FLAGS_cluster, FLAGS_cluster,
FLAGS_threads); FLAGS_threads,
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
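// A hypothetical invocation of this custom case (binary name and values are
// illustrative, not from the patch); the gflags defined above map directly to
// command-line switches:
//   ./test_sgemv --M=64 --K=128 --flag_bias=true --flag_act=4 \
//       --leakey_relu_alpha=0.1 --threads=2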
if (!flag) { if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!"; << ", act: " << FLAGS_flag_act << " failed!!";
} }
LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " passed!!"; << ", act: " << FLAGS_flag_act << " passed!!";
} }
...@@ -177,7 +177,9 @@ static void basic_gemv(int m, ...@@ -177,7 +177,9 @@ static void basic_gemv(int m,
type2 beta, type2 beta,
bool trans_a = false, bool trans_a = false,
bool flag_bias = false, bool flag_bias = false,
bool flag_relu = false) { int flag_act = 0,
float six = 6.f,
float leakey_relu_alpha = 1.f) {
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < m; ++i) { for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0); auto bias_data = static_cast<type2>(0);
...@@ -195,8 +197,15 @@ static void basic_gemv(int m, ...@@ -195,8 +197,15 @@ static void basic_gemv(int m,
sum += av * b[j]; sum += av * b[j];
} }
type2 tmp = alpha * sum + beta * c[i] + bias_data; type2 tmp = alpha * sum + beta * c[i] + bias_data;
if (flag_relu) { if (flag_act > 0) {
if (flag_act == 1) { // relu
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0;
c[i] = c[i] < six ? c[i] : six;
} else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
}
} else { } else {
c[i] = tmp; c[i] = tmp;
} }
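// A hedged usage sketch of the extended reference above (assumes this header
// is included; the data and the name basic_gemv_relu6_example are
// illustrative only): a 2x3 gemv with bias and a fused relu6.
static void basic_gemv_relu6_example() {
  float a[6] = {1.f, 2.f, 3.f, -1.f, -2.f, -3.f};  // A, 2 x 3, row major
  float b[3] = {4.f, 5.f, 6.f};                    // x
  float bias[2] = {0.5f, 0.5f};
  float c[2] = {0.f, 0.f};
  // row 0: 1*4 + 2*5 + 3*6 + 0.5 = 32.5  -> clamped to 6 by relu6
  // row 1: -(1*4 + 2*5 + 3*6) + 0.5 = -31.5 -> clipped to 0 by relu6
  basic_gemv(2, 3, a, b, bias, c, 1.f, 0.f,
             false /*trans_a*/, true /*flag_bias*/,
             2 /*flag_act: relu6*/, 6.f /*six*/, 1.f /*alpha*/);
}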
...@@ -230,7 +239,9 @@ static void conv_basic(const Dtype1* din, ...@@ -230,7 +239,9 @@ static void conv_basic(const Dtype1* din,
int pad_w, int pad_w,
int pad_h, int pad_h,
bool flag_bias, bool flag_bias,
bool flag_relu) { int act_type,
float six = 6.f,
float scale = 1.f) {
Dtype2 beta = 0; Dtype2 beta = 0;
auto src_data = din; auto src_data = din;
auto dst_data_ref = dout; auto dst_data_ref = dout;
...@@ -280,10 +291,27 @@ static void conv_basic(const Dtype1* din, ...@@ -280,10 +291,27 @@ static void conv_basic(const Dtype1* din,
} }
} }
} }
if (flag_relu) { if (act_type > 0) {
// 1-relu 2-relu6 4-leakyrelu
if (act_type == 1) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx] ? dst_data_ref[out_idx]
: (Dtype2)0; : (Dtype2)0;
} else if (act_type == 2) {
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
? dst_data_ref[out_idx]
: (Dtype2)six;
} else if (act_type == 4) {
dst_data_ref[out_idx] =
dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)(dst_data_ref[out_idx] * scale);
} else {
printf("this act type: %d does not support \n", act_type);
}
} }
} }
} }
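// Worked example (illustrative values only) of the leaky-relu branch above
// with the scale 8.88 used by the fp32 random test: an accumulated output of
// -0.5 is rewritten to -0.5 * 8.88 = -4.44, while +0.5 passes through
// unchanged.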
......