From 24efaf4194ca34e9b7d2f9f42ed932f417f11fa8 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Tue, 14 Jul 2020 17:54:35 +0800 Subject: [PATCH] fix compute error, test=develop --- lite/backends/arm/math/conv_block_utils.h | 29 +++++++++++++++-------- lite/kernels/arm/sequence_conv_compute.cc | 7 ------ 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 8e108c150f..f83fa71bd7 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -160,7 +160,6 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { float32x4_t din1 = vld1q_f32(din_ptr1); float32x4_t din2 = vld1q_f32(din_ptr2); float32x4_t din3 = vld1q_f32(din_ptr3); - dout_ptr0 += nn_num; Dtype* dout_ptr1 = dout_ptr0 + n; Dtype* dout_ptr2 = dout_ptr1 + n; Dtype* dout_ptr3 = dout_ptr2 + n; @@ -171,17 +170,27 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { din_ptr0 += 4; din_ptr1 += 4; // a00 b00 c00 d00 a02 b02 c02 d02 - float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]); // a01 b01 c01 d01 a03 b03 c03 d03 - float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]); + float tmp_val1 = tmp0.val[0][2]; + float tmp_val2 = tmp0.val[0][3]; + tmp0.val[0][2] = tmp2.val[0][0]; + tmp0.val[0][3] = tmp2.val[0][1]; + float tmp_val3 = tmp0.val[1][2]; + float tmp_val4 = tmp0.val[1][3]; + tmp2.val[0][0] = tmp_val1; + tmp2.val[0][1] = tmp_val2; + tmp0.val[1][2] = tmp2.val[1][0]; + tmp0.val[1][3] = tmp2.val[1][1]; + tmp2.val[1][0] = tmp_val3; + tmp2.val[1][1] = tmp_val4; din_ptr2 += 4; din_ptr3 += 4; - vst1q_f32(dout_ptr0, tmp00.val[0]); - vst1q_f32(dout_ptr1, tmp02.val[0]); - vst1q_f32(dout_ptr2, tmp00.val[1]); - vst1q_f32(dout_ptr3, tmp02.val[1]); + vst1q_f32(dout_ptr0, tmp0.val[0]); + vst1q_f32(dout_ptr1, tmp0.val[1]); + dout_ptr0 += nn_num; + vst1q_f32(dout_ptr2, tmp2.val[0]); + vst1q_f32(dout_ptr3, tmp2.val[1]); } - dout_ptr0 += nn_num; for (int y = 0; y < remain_m; y++) { *dout_ptr0++ = *din_ptr0++; *dout_ptr0++ = *din_ptr1++; @@ -190,21 +199,21 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { } } const Dtype* din_ptr0 = din + cnt_n * mm_num; + dout = dout + cnt_n * 4; for (int x = 0; x < remain_n; x++) { Dtype* dout_ptr0 = dout + x * 4; for (int y = 0; y < cnt_m; y++) { float32x4_t din0 = vld1q_f32(din_ptr0); - dout_ptr0 += nn_num; Dtype* dout_ptr1 = dout_ptr0 + n; Dtype* dout_ptr2 = dout_ptr1 + n; Dtype* dout_ptr3 = dout_ptr2 + n; din_ptr0 += 4; *dout_ptr0 = din0[0]; *dout_ptr1 = din0[1]; + dout_ptr0 += nn_num; *dout_ptr2 = din0[2]; *dout_ptr3 = din0[3]; } - dout_ptr0 += nn_num; for (int y = 0; y < remain_m; y++) { *dout_ptr0++ = *din_ptr0++; } diff --git a/lite/kernels/arm/sequence_conv_compute.cc b/lite/kernels/arm/sequence_conv_compute.cc index 71d1f747c8..71826ae673 100644 --- a/lite/kernels/arm/sequence_conv_compute.cc +++ b/lite/kernels/arm/sequence_conv_compute.cc @@ -102,17 +102,10 @@ void SequenceConvCompute::Run() { 1, 1, // stride_h, stride_w, dilation_h, dilation_w tmp_data); -#if 0 local_naive_transpose(tmp_data, sub_col_data, kernel_size * hidden_dim, input_row_end - input_row_begin); -#else - paddle::lite::arm::math::transpose(tmp_data, - sub_col_data, - kernel_size * hidden_dim, - input_row_end - input_row_begin); -#endif } } -- GitLab