diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 9fa3a6314481dbd1c12bb63976cb25e2592b9904..8e108c150fc1b403c09140fdbc8ead5b30cd7981 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -147,8 +147,8 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { int remain_n = n & 3; int cnt_m = m >> 2; int remain_m = m & 3; - int nn_num = n << 2; // n * 4 - int mm_num = m << 2; // m * 4 + int nn_num = n << 2; // n * 4 + int mm_num = m << 2; // m * 4 for (int x = 0; x < cnt_n; x++) { const Dtype* din_ptr0 = din + x * mm_num; const Dtype* din_ptr1 = din_ptr0 + m; @@ -156,7 +156,7 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { const Dtype* din_ptr3 = din_ptr2 + m; Dtype* dout_ptr0 = dout + x * 4; for (int y = 0; y < cnt_m; y++) { - float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03 + float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03 float32x4_t din1 = vld1q_f32(din_ptr1); float32x4_t din2 = vld1q_f32(din_ptr2); float32x4_t din3 = vld1q_f32(din_ptr3); @@ -164,16 +164,16 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { Dtype* dout_ptr1 = dout_ptr0 + n; Dtype* dout_ptr2 = dout_ptr1 + n; Dtype* dout_ptr3 = dout_ptr2 + n; - float32x4x2_t tmp0 = vtrnq_f32(din0, din1); // a00 b00 a02 b02 - //float32x4_t tmp1 = vtrn2q_f32(din0, din1); // a01 b01 a03 b03 - float32x4x2_t tmp2 = vtrnq_f32(din2, din3); // c00 d00 c02 d02 - // float32x4_t tmp3 = vtrn2q_f32(din2, din3); // c01 d01 c03 d03 + // a00 b00 a02 b02 a01 b01 a03 b03 + float32x4x2_t tmp0 = vtrnq_f32(din0, din1); + // c00 d00 c02 d02 c01 d01 c03 d03 + float32x4x2_t tmp2 = vtrnq_f32(din2, din3); din_ptr0 += 4; din_ptr1 += 4; - float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]); // a00 b00 c00 d00 - // float32x4_t tmp01 = vtrn2q_f32(tmp0, tmp2); // a02 b02 c02 d02 - float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]); // a01 b01 c01 d01 - // float32x4_t tmp03 = vtrn2q_f32(tmp1, tmp3); // a03 b03 c03 d03 + // a00 b00 c00 d00 a02 b02 c02 d02 + float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]); + // a01 b01 c01 d01 a03 b03 c03 d03 + float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]); din_ptr2 += 4; din_ptr3 += 4; vst1q_f32(dout_ptr0, tmp00.val[0]); @@ -193,7 +193,7 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) { for (int x = 0; x < remain_n; x++) { Dtype* dout_ptr0 = dout + x * 4; for (int y = 0; y < cnt_m; y++) { - float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03 + float32x4_t din0 = vld1q_f32(din_ptr0); dout_ptr0 += nn_num; Dtype* dout_ptr1 = dout_ptr0 + n; Dtype* dout_ptr2 = dout_ptr1 + n; diff --git a/lite/backends/arm/math/sequence_pool.cc b/lite/backends/arm/math/sequence_pool.cc index 91a82bf96287ae38efc601093b0f69b5f838114b..bd42330c2bee63794ee808167ed8bfee5364a0d3 100644 --- a/lite/backends/arm/math/sequence_pool.cc +++ b/lite/backends/arm/math/sequence_pool.cc @@ -54,7 +54,7 @@ void seq_pool_sum(const float* din, din_ptr += width; } #else - int cnt_w = height >> 2; + int cnt_w = height >> 2; int remain_w = height & 3; int cnt_h = height >> 2; int remain_h = height & 3; @@ -211,62 +211,62 @@ void seq_pool_max(const float* din, din_ptr += width; } #else - int cnt_w = height >> 2; - int remain_w = height & 3; - int cnt_h = height >> 2; - int remain_h = height & 3; - int stride = width << 2; - for (int w = 0; w < cnt_w; w++) { - const float* din_ptr0 = din_ptr + w * 4; - float32x4_t dout_val = vld1q_f32(dout_ptr); - const float* din_ptr1 = din_ptr0 + width; - const float* din_ptr2 = din_ptr1 + width; - const float* din_ptr3 = din_ptr2 + width; - for (int h = 0; h < cnt_h; h++) { - float32x4_t din0 = vld1q_f32(din_ptr0); - float32x4_t din1 = vld1q_f32(din_ptr1); - float32x4_t din2 = vld1q_f32(din_ptr2); - float32x4_t din3 = vld1q_f32(din_ptr3); - dout_val = vmaxq_f32(din0, dout_val); - float32x4_t tmp = vmaxq_f32(din1, din2); - din_ptr0 += stride; - din_ptr1 += stride; - dout_val = vmaxq_f32(din3, dout_val); - din_ptr2 += stride; - din_ptr3 += stride; - dout_val = vmaxq_f32(tmp, dout_val); - } - for (int h = 0; h < remain_h; h++) { - float32x4_t din0 = vld1q_f32(din_ptr0); - dout_val = vmaxq_f32(din0, dout_val); - din_ptr0 += width; - } - vst1q_f32(dout_ptr, dout_val); - dout_ptr += 4; - } - const float* din_ptr00 = din_ptr + cnt_w * 4; - for (int w = 0; w < remain_w; w++) { - const float* din_ptr0 = din_ptr00 + w; - const float* din_ptr1 = din_ptr0 + width; - const float* din_ptr2 = din_ptr1 + width; - const float* din_ptr3 = din_ptr2 + width; - for (int h = 0; h < cnt_h; h++) { - *dout_ptr += din_ptr0[0]; - *dout_ptr = std::max(*dout_ptr, din_ptr0[0]); - float tmp = std::max(din_ptr1[0], din_ptr2[0]); - din_ptr0 += stride; - din_ptr1 += stride; - *dout_ptr = std::max(*dout_ptr, din_ptr3[0]); - din_ptr2 += stride; - din_ptr3 += stride; - *dout_ptr = std::max(*dout_ptr, tmp); + int cnt_w = height >> 2; + int remain_w = height & 3; + int cnt_h = height >> 2; + int remain_h = height & 3; + int stride = width << 2; + for (int w = 0; w < cnt_w; w++) { + const float* din_ptr0 = din_ptr + w * 4; + float32x4_t dout_val = vld1q_f32(dout_ptr); + const float* din_ptr1 = din_ptr0 + width; + const float* din_ptr2 = din_ptr1 + width; + const float* din_ptr3 = din_ptr2 + width; + for (int h = 0; h < cnt_h; h++) { + float32x4_t din0 = vld1q_f32(din_ptr0); + float32x4_t din1 = vld1q_f32(din_ptr1); + float32x4_t din2 = vld1q_f32(din_ptr2); + float32x4_t din3 = vld1q_f32(din_ptr3); + dout_val = vmaxq_f32(din0, dout_val); + float32x4_t tmp = vmaxq_f32(din1, din2); + din_ptr0 += stride; + din_ptr1 += stride; + dout_val = vmaxq_f32(din3, dout_val); + din_ptr2 += stride; + din_ptr3 += stride; + dout_val = vmaxq_f32(tmp, dout_val); + } + for (int h = 0; h < remain_h; h++) { + float32x4_t din0 = vld1q_f32(din_ptr0); + dout_val = vmaxq_f32(din0, dout_val); + din_ptr0 += width; + } + vst1q_f32(dout_ptr, dout_val); + dout_ptr += 4; } - for (int h = 0; h < remain_h; h++) { - *dout_ptr = std::max(*dout_ptr, din_ptr0[0]); - din_ptr0 += width; + const float* din_ptr00 = din_ptr + cnt_w * 4; + for (int w = 0; w < remain_w; w++) { + const float* din_ptr0 = din_ptr00 + w; + const float* din_ptr1 = din_ptr0 + width; + const float* din_ptr2 = din_ptr1 + width; + const float* din_ptr3 = din_ptr2 + width; + for (int h = 0; h < cnt_h; h++) { + *dout_ptr += din_ptr0[0]; + *dout_ptr = std::max(*dout_ptr, din_ptr0[0]); + float tmp = std::max(din_ptr1[0], din_ptr2[0]); + din_ptr0 += stride; + din_ptr1 += stride; + *dout_ptr = std::max(*dout_ptr, din_ptr3[0]); + din_ptr2 += stride; + din_ptr3 += stride; + *dout_ptr = std::max(*dout_ptr, tmp); + } + for (int h = 0; h < remain_h; h++) { + *dout_ptr = std::max(*dout_ptr, din_ptr0[0]); + din_ptr0 += width; + } + dout_ptr++; } - dout_ptr++; - } #endif } } diff --git a/lite/kernels/arm/sequence_conv_compute.cc b/lite/kernels/arm/sequence_conv_compute.cc index ef29ac2318695fe4ea687ad4d1877423d2b6c0ae..b7432659abec57f6c4c375488930b3bdc4438bf2 100644 --- a/lite/kernels/arm/sequence_conv_compute.cc +++ b/lite/kernels/arm/sequence_conv_compute.cc @@ -17,8 +17,8 @@ limitations under the License. */ #include #include #include -#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/sgemm.h" #include "lite/core/op_registry.h" #include "lite/core/tensor.h" @@ -108,13 +108,12 @@ void SequenceConvCompute::Run() { kernel_size * hidden_dim, input_row_end - input_row_begin); #else - paddle::lite::arm::math::transpose( - tmp_data, - sub_col_data, - kernel_size * hidden_dim, - input_row_end - input_row_begin); -#endif - } + paddle::lite::arm::math::transpose(tmp_data, + sub_col_data, + kernel_size * hidden_dim, + input_row_end - input_row_begin); +#endif + } } // SGDMM C := alpha * A * B + beta * C