提交 443a7380 编写于 作者: C chenjiaoAngel

fix test

上级 047a22cb
......@@ -615,7 +615,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"1: \n" \
"subs %[cnt], #1\n" \
"vmla.f32 q15, q8, %e[wr0][0]\n" /*0123*wr0[0]*/ \
"vmul.f32 q14, q9, %e[wr5][1]\n" /*4567*wr5[2]*/ \
"vmul.f32 q14, q9, %f[wr5][0]\n" /*4567*wr5[2]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr1]]!\n" \
"vmla.f32 q15, q10, %e[wr0][1]\n" /*1234*wr0[1]*/\
"vld1.f32 {d18-d19}, [%[din_ptr1]]\n" \
......@@ -626,7 +626,7 @@ void conv_depthwise_5x5s1_fp32(float* dout,
"vext.32 q12, q8, q9, #3\n" \
"vmla.f32 q14, q8, %e[wr1][0]\n" /*0123*wr1[0]*/ \
"vld1.f32 {d16-d17}, [%[din_ptr2]]!\n" \
"vmla.f32 q15, q9, %f[wr5][0]\n" /*4567*wr5[3]*/ \
"vmla.f32 q15, q9, %f[wr5][1]\n" /*4567*wr5[3]*/ \
"vld1.f32 {d18-d19}, [%[din_ptr2]]\n" \
"vmla.f32 q14, q10, %e[wr1][1]\n" /*1234*wr1[1]*/\
"vext.32 q10, q8, q9, #1\n" \
......@@ -945,7 +945,7 @@ inline float compute_one_data_post(const float* data, float32x4_t wr, float bias
inline void compute_all_padding_pre(float* dout,
const float** din_ptr_arr,
const float* bias,
std::vector<float32x4_t> weights,
float32x4_t* weights,
int win,
int wout,
int pad_left,
......@@ -959,6 +959,7 @@ inline void compute_all_padding_pre(float* dout,
for (int w = pad_left; w > 4; w--) {
*dout++ = bias[0];
}
LOG(INFO) << "pad_left_new: " << pad_left_new;
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i);
for (int k = 0; k < num; k++) {
......@@ -1158,6 +1159,7 @@ inline void compute_all_padding_pre(float* dout,
default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support";
}
din_ptr_arr[0] -= 4;
}
// remain
for (int w = 0; w < remain; w++) {
......@@ -1169,44 +1171,17 @@ inline void compute_all_padding_pre(float* dout,
}
*dout++ = sum;
}
LOG(INFO) << " pad_right_new: " << pad_right_new;
// right
for (int i = 1; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][4 - i], 4 - i);
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][4 - i], 4 - i);
sum += compute_one_data_post(din_ptr_arr[num - 1 - k], weights[3 - k], 0.f, weights[3 - k][3 - i], 3 - i);
din_ptr_arr[num - 1 - k]++;
}
*dout++ = sum;
}
/*
switch (pad_right_new) {
case 1:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][3], 3);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][3], 3);
}
*dout++ = sum;
case 2:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][2], 2);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][2], 2);
}
*dout++ = sum;
case 3:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][1], 1);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][1], 1);
}
*dout++ = sum;
case 4:
float sum = compute_one_data_post(din_ptr_arr[num], weights[4], bias[0], weights[4][0], 0);
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[num - 1 - i], weights[3 - i], 0.f, weights[3 - i][0], 0);
}
*dout++ = sum;
}
*/
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0];
}
......@@ -1215,7 +1190,7 @@ inline void compute_all_padding_pre(float* dout,
inline void compute_all_padding_mid(float* dout,
const float** din_ptr_arr,
const float* bias,
std::vector<float32x4_t> weights,
float32x4_t* weights,
int win,
int wout,
int pad_left,
......@@ -1293,7 +1268,8 @@ inline void compute_all_padding_mid(float* dout,
"q13",
"q14",
"q15");
#endif
#endif
din_ptr_arr[0] -= 4;
}
// remain
for (int w = 0; w < remain; w++) {
......@@ -1305,15 +1281,17 @@ inline void compute_all_padding_mid(float* dout,
}
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[num]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[tmp - k]++;
}
*dout++ = sum;
}
for (int w = pad_right; w > 4; w--) {
*dout++ = bias[0];
}
......@@ -1321,7 +1299,7 @@ inline void compute_all_padding_mid(float* dout,
inline void compute_all_padding_post(float* dout,
const float** din_ptr_arr,
const float* bias,
std::vector<float32x4_t> weights,
float32x4_t* weights,
int win,
int wout,
int pad_left,
......@@ -1337,9 +1315,9 @@ inline void compute_all_padding_post(float* dout,
}
int tmp = num - 1;
for (int i = pad_left_new; i > 0; i--) {
float sum = compute_one_data_pre(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4 - i);
float sum = compute_one_data_pre(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i);
for (int k = 0; k < num; k++) {
sum += compute_one_data_pre(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
sum += compute_one_data_pre(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[5][tmp - k], 4 - i);
}
*dout++ = sum;
}
......@@ -1350,7 +1328,7 @@ inline void compute_all_padding_post(float* dout,
#ifdef __aarch64__
asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr0] "+r"(din_ptr_arr[3]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -1368,7 +1346,7 @@ inline void compute_all_padding_post(float* dout,
#else
asm volatile(COMPUTE_ONE_LINE_S1_POST RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr0] "+r"(din_ptr_arr[3]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr5] "w"(weights[5]),
......@@ -1384,13 +1362,14 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[3] -= 4;
break;
case 1:
#ifdef __aarch64__
asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1409,8 +1388,8 @@ inline void compute_all_padding_post(float* dout,
#else
asm volatile(COMPUTE_TWO_LINE_S1_POST RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr0] "+r"(din_ptr_arr[2]),
[din_ptr1] "+r"(din_ptr_arr[3]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1427,14 +1406,15 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[2] -= 4;
break;
case 2:
#ifdef __aarch64__
asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr2] "+r"(din_ptr_arr[2]),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1454,9 +1434,9 @@ inline void compute_all_padding_post(float* dout,
#else
asm volatile(COMPUTE_THREE_LINE_S1_POST RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr_arr[0]),
[din_ptr1] "+r"(din_ptr_arr[1]),
[din_ptr2] "+r"(din_ptr_arr[2]),
[din_ptr0] "+r"(din_ptr_arr[1]),
[din_ptr1] "+r"(din_ptr_arr[2]),
[din_ptr2] "+r"(din_ptr_arr[3]),
[dout_ptr] "+r"(dout)
: [wr0] "w"(weights[0]),
[wr1] "w"(weights[1]),
......@@ -1474,6 +1454,7 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[1] -= 4;
break;
case 3:
#ifdef __aarch64__
......@@ -1525,6 +1506,7 @@ inline void compute_all_padding_post(float* dout,
"q14",
"q15");
#endif
din_ptr_arr[0] -= 4;
break;
default:
LOG(FATAL) << "This num: " << (num + 1) << "does not support";
......@@ -1532,20 +1514,22 @@ inline void compute_all_padding_post(float* dout,
}
// remain
for (int w = 0; w < remain; w++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[num]++;
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4);
din_ptr_arr[3]++;
for (int i = 0; i < num; i++) {
sum += compute_one_data_post(din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[tmp - i]++;
sum += compute_one_data_post(din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4);
din_ptr_arr[2 - i]++;
}
*dout++ = sum;
}
// right
for (int i = 0; i < pad_right_new; i++) {
float sum = compute_one_data_post(din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i);
float sum = compute_one_data_post(din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i);
din_ptr_arr[3]++;
for (int k = 0; k < num; k++) {
sum += compute_one_data_post(din_ptr_arr[tmp - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
sum += compute_one_data_post(din_ptr_arr[2 - k], weights[tmp - k], 0.f, weights[tmp - k][3 - i], 3 - i);
din_ptr_arr[2 - k]++;
}
*dout++ = sum;
}
......@@ -1572,6 +1556,8 @@ void conv_depthwise_5x5s1_bias(float* dout,
ARMContext* ctx){
int loop_w = wout - pad_left - pad_right;
int loop_h = hout - pad_top - pad_bottom;
LOG(INFO) << "pad_top: " << pad_top << ", pad_bottom: " << pad_bottom;
LOG(INFO) << "pad_left: " << pad_left << ", pad_right: " << pad_right;
int in_size = win * hin;
int out_size = wout * hout;
int cnt = loop_w >> 2;
......@@ -1612,32 +1598,54 @@ void conv_depthwise_5x5s1_bias(float* dout,
wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3);
wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0);
const float* din_ptr_arr[] = {din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4};
float32x4_t wei_vwc[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h
float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6};
// top_h
for (int h = pad_top; h > 4; h--) {
memset(dout_ptr, bias[0], sizeof(float)*wout);
dout_ptr += wout;
}
for (int h = pad_top_new; h > 0; h--) {
compute_all_padding_pre(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 4 - h);
pad_right, pad_left_new, pad_right_new, cnt, remain, 4 - h);
dout_ptr += wout;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
// mid_h
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
for (int h = 0; h < loop_h; h++) {
compute_all_padding_mid(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 4);
pad_right, pad_left_new, pad_right_new, cnt, remain, 4);
dout_ptr += wout;
for (int i = 0; i < 4; i++) {
din_ptr_arr[i] = din_ptr_arr[i + 1];
}
din_ptr_arr[4] += win;
din_ptr0 = din_ptr1;
din_ptr1 = din_ptr2;
din_ptr2 = din_ptr3;
din_ptr3 = din_ptr4;
din_ptr4 = din_ptr4 + win;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
// bottom
for (int h = 0; h < pad_bottom_new; h++) {
for (int i = 0; i < 5; i++) LOG(INFO) << "i: " << i << ", ptr: " << din_ptr_arr[i];
LOG(INFO) << "num: " << (3 -h);
compute_all_padding_post(dout_ptr, din_ptr_arr, vbias, weights_vec, win, wout, pad_left,
pad_left_new, pad_right, pad_right_new, cnt, remain, 3 - h);
pad_right, pad_left_new, pad_right_new, cnt, remain, 3 - h);
dout_ptr += wout;
din_ptr_arr[0] = din_ptr0;
din_ptr_arr[1] = din_ptr1;
din_ptr_arr[2] = din_ptr2;
din_ptr_arr[3] = din_ptr3;
din_ptr_arr[4] = din_ptr4;
}
}
}
......
......@@ -751,7 +751,7 @@ void conv_depthwise_5x5_fp32(const void* din,
act_param,
ctx);
} else if (stride == 1) {
#if 1
#if 0
conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
reinterpret_cast<const float*>(din),
reinterpret_cast<const float*>(weights),
......
......@@ -56,8 +56,16 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
} else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32";
auto strides = param.strides;
if ((strides[0] == 1 && strides[1] == 1) ||
(strides[0] == 2 && strides[1] == 2)) {
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
if (win >= kw && hin >= kw && (strides[0] == 1 && strides[1] == 1)) {
flag_trans_weights_ = false;
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_depthwise_5x5_fp32";
#endif
} else if ((strides[0] == 1 && strides[1] == 1) ||
(strides[0] == 2 && strides[1] == 2)) {
// trans weights
constexpr int cblock = 4;
auto oc = w_dims[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册