提交 ae8cfcec 编写于 作者: C chenjiaoAngel

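Fix kernel selection for the 5x5 stride-1 depthwise convolution (and correct the AArch64 assembly macros and `__aarch64__` guards)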

上级 4dbda3aa
......@@ -193,7 +193,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
}
#ifdef __aarch64__
#define COMPUTE_ONE_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -213,7 +213,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_TWO_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -243,7 +243,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_THREE_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -283,7 +283,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FOUR_LINE_S1_PRE \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -333,7 +333,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FIVE_LINE_S1 \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -393,7 +393,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_ONE_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -413,7 +413,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_TWO_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -443,7 +443,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_THREE_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -483,7 +483,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define COMPUTE_FOUR_LINE_S1_POST \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"ld1 {v9.4s}, [%[din_ptr0]], #16\n" \
"ld1 {v10.4s}, [%[din_ptr0]]\n" \
"ext v11.16b, v9.16b, v10.16b, #4\n" \
......@@ -533,25 +533,25 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"ext v13.16b, v9.16b, v10.16b, #12\n" \
"fadd v16.4s, v14.4s, v15.4s\n"
#define RESULT_S1 \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b"
#define RESULT_S1_RELU \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"fmax v16.4s, v16.4s, %[vzero]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"fmax v16.4s, v16.4s, %[vzero].4s\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b"
#define RESULT_S1_RELU6 \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"fmax v16.4s, v16.4s, %[vzero]]\n" \
"fmin v16.4s, v16.4s, %[vsix]]\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"fmax v16.4s, v16.4s, %[vzero].4s\n" \
"fmin v16.4s, v16.4s, %[vsix].4s\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b"
#define RESULT_S1_LEAKY_RELU \
"ld1 {v15.4s}, [%[bias]], #16\n" \
"ld1 {v15.4s}, [%[bias]]\n" \
"fcmge v17.4s, v16.4s, %[vzero].4s\n" \
"fmul v18.4s, v16.4s, %[vscale].4s\n" \
"bif v16.4s, v18.4s, v17.4s\n" \
"bif v16.16b, v18.16b, v17.16b\n" \
"st1 {v16.4s}, [%[dout_ptr]], #16\n" \
"bne 1b"
#else
......@@ -1213,7 +1213,7 @@ inline void compute_all_padding_mid(float* dout,
}
// mid
if (cnt > 0) {
#ifdef __aarch64_
#ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
......@@ -1611,11 +1611,6 @@ void conv_depthwise_5x5s1_bias(float* dout,
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
// din_ptr_arr[0] = din_ptr0;
// din_ptr_arr[1] = din_ptr1;
// din_ptr_arr[2] = din_ptr2;
// din_ptr_arr[3] = din_ptr3;
// din_ptr_arr[4] = din_ptr4;
// mid_h
for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
......@@ -1926,7 +1921,7 @@ inline void compute_all_padding_mid_relu(float* dout,
*dout++ = sum > 0.f ? sum : 0.f;
}
if (cnt > 0) {
#ifdef __aarch64_
#ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
......@@ -2335,11 +2330,6 @@ void conv_depthwise_5x5s1_bias_relu(float* dout,
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
// din_ptr_arr[0] = din_ptr0;
// din_ptr_arr[1] = din_ptr1;
// din_ptr_arr[2] = din_ptr2;
// din_ptr_arr[3] = din_ptr3;
// din_ptr_arr[4] = din_ptr4;
// mid_h
for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid_relu(dout_ptr, din_ptr_arr, vbias, weights_vec, vzero, win, wout, pad_left,
......@@ -2667,7 +2657,7 @@ inline void compute_all_padding_mid_relu6(float* dout,
*dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f;
}
if (cnt > 0) {
#ifdef __aarch64_
#ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
......@@ -3178,7 +3168,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_ONE_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3224,7 +3216,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_TWO_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3275,7 +3269,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_THREE_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3330,7 +3326,9 @@ inline void compute_all_padding_pre_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_FOUR_LINE_S1_PRE RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3421,7 +3419,7 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
*dout++ = sum > 0.f ? sum : sum * scale[0];
}
if (cnt > 0) {
#ifdef __aarch64_
#ifdef __aarch64__
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
......@@ -3449,7 +3447,9 @@ inline void compute_all_padding_mid_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_FIVE_LINE_S1 RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3560,7 +3560,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3606,7 +3608,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3656,7 +3660,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......@@ -3710,7 +3716,9 @@ inline void compute_all_padding_post_leakyRelu(float* dout,
"v13",
"v14",
"v15",
"v16");
"v16",
"v17",
"v18");
#else
asm volatile(COMPUTE_FOUR_LINE_S1_POST RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
......
......@@ -751,43 +751,43 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param,
ctx);
} else if (stride == 1) {
#if 0
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
param,
ctx);
#else
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
paddings[0],
paddings[1],
paddings[2],
paddings[3],
act_param,
ctx);
#endif
if (h_in < 5 || w_in < 5) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
param,
ctx);
} else {
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
paddings[0],
paddings[1],
paddings[2],
paddings[3],
act_param,
ctx);
}
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册