提交 ae8cfcec 编写于 作者: C chenjiaoAngel

fix kernel choose

上级 4dbda3aa
...@@ -193,7 +193,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -193,7 +193,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
} }
#ifdef __aarch64__ #ifdef __aarch64__
#define COMPUTE_ONE_LINE_S1_PRE \ #define COMPUTE_ONE_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -213,7 +213,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -213,7 +213,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_TWO_LINE_S1_PRE \ #define COMPUTE_TWO_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -243,7 +243,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -243,7 +243,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_THREE_LINE_S1_PRE \ #define COMPUTE_THREE_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -283,7 +283,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -283,7 +283,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FOUR_LINE_S1_PRE \ #define COMPUTE_FOUR_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -333,7 +333,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -333,7 +333,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FIVE_LINE_S1 \ #define COMPUTE_FIVE_LINE_S1 \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -393,7 +393,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -393,7 +393,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_ONE_LINE_S1_POST \ #define COMPUTE_ONE_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -413,7 +413,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -413,7 +413,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_TWO_LINE_S1_POST \ #define COMPUTE_TWO_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -443,7 +443,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -443,7 +443,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_THREE_LINE_S1_POST \ #define COMPUTE_THREE_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -483,7 +483,7 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -483,7 +483,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FOUR_LINE_S1_POST \ #define COMPUTE_FOUR_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \ "ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \ "ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \ "ext v11.16b, v9.16b, v10.16b, #4\n" \
...@@ -533,25 +533,25 @@ void conv_depthwise_5x5s1_fp32(float* dout, ...@@ -533,25 +533,25 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \ "ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n" "fadd v16.4s, v14.4s, v15.4s\n"
#define RESULT_S1 \ #define RESULT_S1 \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \ "st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b" "bne 1b"
#define RESULT_S1_RELU \ #define RESULT_S1_RELU \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"fmax v16.4s, v16.4s, %[vzero]]\n" \ "fmax v16.4s, v16.4s, %[vzero].4s\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \ "st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b" "bne 1b"
#define RESULT_S1_RELU6 \ #define RESULT_S1_RELU6 \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"fmax v16.4s, v16.4s, %[vzero]]\n" \ "fmax v16.4s, v16.4s, %[vzero].4s\n" \
"fmin v16.4s, v16.4s, %[vsix]]\n" \ "fmin v16.4s, v16.4s, %[vsix].4s\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \ "st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b" "bne 1b"
#define RESULT_S1_LEAKY_RELU \ #define RESULT_S1_LEAKY_RELU \
"ld1 {v15.4s}, [%[bias]], #16\n" \ "ld1 {v15.4s}, [%[bias]]\n" \
"fcmge v17.4s, v16.4s, %[vzero].4s\n" \ "fcmge v17.4s, v16.4s, %[vzero].4s\n" \
"fmul v18.4s, v16.4s, %[vscale].4s\n" \ "fmul v18.4s, v16.4s, %[vscale].4s\n" \
"bif v16.4s, v18.4s, v17.4s\n" \ "bif v16.16b, v18.16b, v17.16b\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \ "st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b" "bne 1b"
#else #else
...@@ -1213,7 +1213,7 @@ inline void compute_all_padding_mid(float* dout, ...@@ -1213,7 +1213,7 @@ inline void compute_all_padding_mid(float* dout,
} }
// mid // mid
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64_ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1 asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
...@@ -1611,11 +1611,6 @@ void conv_depthwise_5x5s1_bias(float* dout, ...@@ -1611,11 +1611,6 @@ void conv_depthwise_5x5s1_bias(float* dout,
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
// din_ptr_arr[0] = din_ptr0;
// din_ptr_arr[1] = din_ptr1;
// din_ptr_arr[2] = din_ptr2;
// din_ptr_arr[3] = din_ptr3;
// din_ptr_arr[4] = din_ptr4;
// mid_h // mid_h
for (int h = 0; h < loop_h; h++) { for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left, compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
...@@ -1926,7 +1921,7 @@ inline void compute_all_padding_mid_relu(float* dout, ...@@ -1926,7 +1921,7 @@ inline void compute_all_padding_mid_relu(float* dout,
*dout++ = sum > 0.f ? sum : 0.f; *dout++ = sum > 0.f ? sum : 0.f;
} }
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64_ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
...@@ -2335,11 +2330,6 @@ void conv_depthwise_5x5s1_bias_relu(float* dout, ...@@ -2335,11 +2330,6 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
din_ptr_arr[3] = din_ptr3; din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4; din_ptr_arr[4] = din_ptr4;
} }
// din_ptr_arr[0] = din_ptr0;
// din_ptr_arr[1] = din_ptr1;
// din_ptr_arr[2] = din_ptr2;
// din_ptr_arr[3] = din_ptr3;
// din_ptr_arr[4] = din_ptr4;
// mid_h // mid_h
for (int h = 0; h < loop_h; h++) { for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left, compute_all_padding_mid_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
...@@ -2667,7 +2657,7 @@ inline void compute_all_padding_mid_relu6(float* dout, ...@@ -2667,7 +2657,7 @@ inline void compute_all_padding_mid_relu6(float* dout,
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
} }
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64_ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6 asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
...@@ -3178,7 +3168,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3178,7 +3168,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_ONE_LINE_S1_PRE RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_ONE_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3224,7 +3216,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3224,7 +3216,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_TWO_LINE_S1_PRE RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_TWO_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3275,7 +3269,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3275,7 +3269,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_THREE_LINE_S1_PRE RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_THREE_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3330,7 +3326,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout, ...@@ -3330,7 +3326,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_FOUR_LINE_S1_PRE RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_FOUR_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3421,7 +3419,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -3421,7 +3419,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
*dout++ = sum > 0.f ? sum : sum * scale[0]; *dout++ = sum > 0.f ? sum : sum * scale[0];
} }
if (cnt > 0) { if (cnt > 0) {
#ifdef __aarch64_ #ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]), [din_ptr0] "+r"(din_ptr_arr[0]),
...@@ -3449,7 +3447,9 @@ inline void compute_all_padding_mid_leakyRelu(float* dout, ...@@ -3449,7 +3447,9 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3560,7 +3560,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -3560,7 +3560,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3606,7 +3608,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -3606,7 +3608,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3656,7 +3660,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -3656,7 +3660,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
...@@ -3710,7 +3716,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout, ...@@ -3710,7 +3716,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16"); "v16",
"v17",
"v18");
#else #else
asm volatile(COMPUTE_FOUR_LINE_S1_POST RESULT_S1_LEAKY_RELU asm volatile(COMPUTE_FOUR_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
......
...@@ -751,43 +751,43 @@ void conv_depthwise_5x5_fp32(const void* din, ...@@ -751,43 +751,43 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param, act_param,
ctx); ctx);
} else if (stride == 1) { } else if (stride == 1) {
#if 0 if (h_in < 5 || w_in < 5) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout), conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din), reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
bias, bias,
flag_bias, flag_bias,
flag_relu, flag_relu,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
pad_w, pad_w,
pad_h, pad_h,
param, param,
ctx); ctx);
#else } else {
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout), conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din), reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
bias, bias,
flag_bias, flag_bias,
flag_relu, flag_relu,
num, num,
ch_in, ch_in,
h_in, h_in,
w_in, w_in,
h_out, h_out,
w_out, w_out,
paddings[0], paddings[0],
paddings[1], paddings[1],
paddings[2], paddings[2],
paddings[3], paddings[3],
act_param, act_param,
ctx); ctx);
#endif }
} else { } else {
LOG(FATAL) << "unsupport this type 5x5 dw conv"; LOG(FATAL) << "unsupport this type 5x5 dw conv";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册