From 13a95dc65cdb1c27d4dd6979a747d85d957baae4 Mon Sep 17 00:00:00 2001
From: chenjiaoAngel
Date: Mon, 13 Jul 2020 17:15:48 +0800
Subject: [PATCH] update sequence_pool and sequence_conv profiler, test=develop

---
 lite/backends/arm/math/conv_block_utils.h |  77 ++++++++++++++
 lite/backends/arm/math/sequence_pool.cc   | 124 +++++++++++++++++++++-
 lite/kernels/arm/sequence_conv_compute.cc |  11 +-
 3 files changed, 209 insertions(+), 3 deletions(-)

diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h
index 9625b1cc03..9fa3a63144 100644
--- a/lite/backends/arm/math/conv_block_utils.h
+++ b/lite/backends/arm/math/conv_block_utils.h
@@ -139,6 +139,83 @@ static bool conv_trans_weights_numc(const dtype* din,
   }
   return true;
 }
+/**
+ * \brief Transpose a row-major n x m float matrix into an m x n one.
+ *
+ * Computes dout[j * n + i] = din[i * m + j] for 0 <= i < n, 0 <= j < m.
+ * The body uses float32 NEON intrinsics, so Dtype is expected to be float.
+ *
+ * \param din  input, n rows of m contiguous elements
+ * \param dout output, m rows of n contiguous elements
+ * \param m    number of input columns (= output rows)
+ * \param n    number of input rows (= output columns)
+ */
+template <typename Dtype>
+void transpose(const Dtype* din, Dtype* dout, int m, int n) {
+  // Work on 4x4 tiles; leftover rows/columns are copied element-wise.
+  int cnt_n = n >> 2;
+  int remain_n = n & 3;
+  int cnt_m = m >> 2;
+  int remain_m = m & 3;
+  int nn_num = n << 2;  // n * 4
+  int mm_num = m << 2;  // m * 4
+  for (int x = 0; x < cnt_n; x++) {
+    const Dtype* din_ptr0 = din + x * mm_num;
+    const Dtype* din_ptr1 = din_ptr0 + m;
+    const Dtype* din_ptr2 = din_ptr1 + m;
+    const Dtype* din_ptr3 = din_ptr2 + m;
+    Dtype* dout_ptr0 = dout + x * 4;
+    for (int y = 0; y < cnt_m; y++) {
+      float32x4_t din0 = vld1q_f32(din_ptr0);  // a0 a1 a2 a3
+      float32x4_t din1 = vld1q_f32(din_ptr1);  // b0 b1 b2 b3
+      float32x4_t din2 = vld1q_f32(din_ptr2);  // c0 c1 c2 c3
+      float32x4_t din3 = vld1q_f32(din_ptr3);  // d0 d1 d2 d3
+      Dtype* dout_ptr1 = dout_ptr0 + n;
+      Dtype* dout_ptr2 = dout_ptr1 + n;
+      Dtype* dout_ptr3 = dout_ptr2 + n;
+      // val[0] = a0 b0 a2 b2, val[1] = a1 b1 a3 b3
+      float32x4x2_t tmp0 = vtrnq_f32(din0, din1);
+      // val[0] = c0 d0 c2 d2, val[1] = c1 d1 c3 d3
+      float32x4x2_t tmp2 = vtrnq_f32(din2, din3);
+      din_ptr0 += 4;
+      din_ptr1 += 4;
+      din_ptr2 += 4;
+      din_ptr3 += 4;
+      // vtrnq only interleaves 32-bit lanes; joining the 64-bit halves
+      // completes the 4x4 transpose.
+      float32x4_t col0 = vcombine_f32(vget_low_f32(tmp0.val[0]),
+                                      vget_low_f32(tmp2.val[0]));
+      float32x4_t col1 = vcombine_f32(vget_low_f32(tmp0.val[1]),
+                                      vget_low_f32(tmp2.val[1]));
+      float32x4_t col2 = vcombine_f32(vget_high_f32(tmp0.val[0]),
+                                      vget_high_f32(tmp2.val[0]));
+      float32x4_t col3 = vcombine_f32(vget_high_f32(tmp0.val[1]),
+                                      vget_high_f32(tmp2.val[1]));
+      vst1q_f32(dout_ptr0, col0);  // a0 b0 c0 d0
+      vst1q_f32(dout_ptr1, col1);  // a1 b1 c1 d1
+      vst1q_f32(dout_ptr2, col2);  // a2 b2 c2 d2
+      vst1q_f32(dout_ptr3, col3);  // a3 b3 c3 d3
+      dout_ptr0 += nn_num;  // advance only after the tile has been stored
+    }
+    for (int y = 0; y < remain_m; y++) {
+      // One leftover input column becomes one 4-element output row chunk.
+      dout_ptr0[0] = *din_ptr0++;
+      dout_ptr0[1] = *din_ptr1++;
+      dout_ptr0[2] = *din_ptr2++;
+      dout_ptr0[3] = *din_ptr3++;
+      dout_ptr0 += n;
+    }
+  }
+  const Dtype* din_ptr0 = din + cnt_n * mm_num;
+  for (int x = 0; x < remain_n; x++) {
+    // Leftover input row -> output column cnt_n * 4 + x (stride n).
+    Dtype* dout_ptr0 = dout + cnt_n * 4 + x;
+    for (int y = 0; y < m; y++) {
+      *dout_ptr0 = *din_ptr0++;
+      dout_ptr0 += n;
+    }
+  }
+}
 /*preprocessing inputs
  * input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws]
  * n = he - hs
diff --git a/lite/backends/arm/math/sequence_pool.cc b/lite/backends/arm/math/sequence_pool.cc
index b8f9ab0a1a..91a82bf962 100644
--- a/lite/backends/arm/math/sequence_pool.cc
+++ b/lite/backends/arm/math/sequence_pool.cc
@@ -46,12 +46,72 @@ void seq_pool_sum(const float* din,
         memcpy(dout_ptr, din_ptr, width * sizeof(float));
         din_ptr += width;
         height = height - 1;
+#if 0
         for (int h = 0; h < height; h++) {
           for (int w = 0; w < width; ++w) {
             dout_ptr[w] += din_ptr[w];
           }
           din_ptr += width;
         }
+#else
+        // Vectorize over columns (4 per step) and unroll rows by 4.
+        int cnt_w = width >> 2;
+        int remain_w = width & 3;
+        int cnt_h = height >> 2;
+        int remain_h = height & 3;
+        int stride = width << 2;
+        for (int w = 0; w < cnt_w; w++) {
+          const float* din_ptr0 = din_ptr + w * 4;
+          float32x4_t dout_val = vld1q_f32(dout_ptr);
+          const float* din_ptr1 = din_ptr0 + width;
+          const float* din_ptr2 = din_ptr1 + width;
+          const float* din_ptr3 = din_ptr2 + width;
+          for (int h = 0; h < cnt_h; h++) {
+            float32x4_t din0 = vld1q_f32(din_ptr0);
+            float32x4_t din1 = vld1q_f32(din_ptr1);
+            float32x4_t din2 = vld1q_f32(din_ptr2);
+            float32x4_t din3 = vld1q_f32(din_ptr3);
+            dout_val = vaddq_f32(din0, dout_val);
+            float32x4_t tmp = vaddq_f32(din1, din2);
+            din_ptr0 += stride;
+            din_ptr1 += stride;
+            dout_val = vaddq_f32(din3, dout_val);
+            din_ptr2 += stride;
+            din_ptr3 += stride;
+            dout_val = vaddq_f32(tmp, dout_val);
+          }
+          for (int h = 0; h < remain_h; h++) {
+            float32x4_t din0 = vld1q_f32(din_ptr0);
+            dout_val = vaddq_f32(din0, dout_val);
+            din_ptr0 += width;
+          }
+          vst1q_f32(dout_ptr, dout_val);
+          dout_ptr += 4;
+        }
+        // Scalar tail: the last width % 4 columns.
+        const float* din_ptr00 = din_ptr + cnt_w * 4;
+        for (int w = 0; w < remain_w; w++) {
+          const float* din_ptr0 = din_ptr00 + w;
+          const float* din_ptr1 = din_ptr0 + width;
+          const float* din_ptr2 = din_ptr1 + width;
+          const float* din_ptr3 = din_ptr2 + width;
+          for (int h = 0; h < cnt_h; h++) {
+            *dout_ptr += din_ptr0[0];
+            float tmp = din_ptr1[0] + din_ptr2[0];
+            din_ptr0 += stride;
+            din_ptr1 += stride;
+            *dout_ptr += din_ptr3[0];
+            din_ptr2 += stride;
+            din_ptr3 += stride;
+            *dout_ptr += tmp;
+          }
+          for (int h = 0; h < remain_h; h++) {
+            *dout_ptr += din_ptr0[0];
+            din_ptr0 += width;
+          }
+          dout_ptr++;
+        }
+#endif
       }
     }
   }
@@ -144,13 +204,73 @@ void seq_pool_max(const float* din,
       } else {
         memcpy(dout_ptr, din_ptr, width * sizeof(float));
         din_ptr += width;
-        int remain_h = height - 1;
-        for (int h = 0; h < remain_h; h++) {
+        height = height - 1;
+#if 0
+        for (int h = 0; h < height; h++) {
           for (int w = 0; w < width; w++) {
             dout_ptr[w] = std::max(dout_ptr[w], din_ptr[w]);
           }
           din_ptr += width;
         }
+#else
+        // Vectorize over columns (4 per step) and unroll rows by 4.
+        int cnt_w = width >> 2;
+        int remain_w = width & 3;
+        int cnt_h = height >> 2;
+        int remain_h = height & 3;
+        int stride = width << 2;
+        for (int w = 0; w < cnt_w; w++) {
+          const float* din_ptr0 = din_ptr + w * 4;
+          float32x4_t dout_val = vld1q_f32(dout_ptr);
+          const float* din_ptr1 = din_ptr0 + width;
+          const float* din_ptr2 = din_ptr1 + width;
+          const float* din_ptr3 = din_ptr2 + width;
+          for (int h = 0; h < cnt_h; h++) {
+            float32x4_t din0 = vld1q_f32(din_ptr0);
+            float32x4_t din1 = vld1q_f32(din_ptr1);
+            float32x4_t din2 = vld1q_f32(din_ptr2);
+            float32x4_t din3 = vld1q_f32(din_ptr3);
+            dout_val = vmaxq_f32(din0, dout_val);
+            float32x4_t tmp = vmaxq_f32(din1, din2);
+            din_ptr0 += stride;
+            din_ptr1 += stride;
+            dout_val = vmaxq_f32(din3, dout_val);
+            din_ptr2 += stride;
+            din_ptr3 += stride;
+            dout_val = vmaxq_f32(tmp, dout_val);
+          }
+          for (int h = 0; h < remain_h; h++) {
+            float32x4_t din0 = vld1q_f32(din_ptr0);
+            dout_val = vmaxq_f32(din0, dout_val);
+            din_ptr0 += width;
+          }
+          vst1q_f32(dout_ptr, dout_val);
+          dout_ptr += 4;
+        }
+        // Scalar tail: the last width % 4 columns.
+        const float* din_ptr00 = din_ptr + cnt_w * 4;
+        for (int w = 0; w < remain_w; w++) {
+          const float* din_ptr0 = din_ptr00 + w;
+          const float* din_ptr1 = din_ptr0 + width;
+          const float* din_ptr2 = din_ptr1 + width;
+          const float* din_ptr3 = din_ptr2 + width;
+          for (int h = 0; h < cnt_h; h++) {
+            *dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
+            float tmp = std::max(din_ptr1[0], din_ptr2[0]);
+            din_ptr0 += stride;
+            din_ptr1 += stride;
+            *dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
+            din_ptr2 += stride;
+            din_ptr3 += stride;
+            *dout_ptr = std::max(*dout_ptr, tmp);
+          }
+          for (int h = 0; h < remain_h; h++) {
+            *dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
+            din_ptr0 += width;
+          }
+          dout_ptr++;
+        }
+#endif
       }
     }
   }
diff --git a/lite/kernels/arm/sequence_conv_compute.cc b/lite/kernels/arm/sequence_conv_compute.cc
index 69740a258b..ef29ac2318 100644
--- a/lite/kernels/arm/sequence_conv_compute.cc
+++ b/lite/kernels/arm/sequence_conv_compute.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include
 #include
 #include "lite/backends/arm/math/conv_impl.h"
+#include "lite/backends/arm/math/conv_block_utils.h"
 #include "lite/backends/arm/math/sgemm.h"
 #include "lite/core/op_registry.h"
 #include "lite/core/tensor.h"
@@ -101,11 +102,19 @@ void SequenceConvCompute::Run() {
               1,
               1,  // stride_h, stride_w, dilation_h, dilation_w
               tmp_data);
+#if 0
           local_naive_transpose(tmp_data,
                                 sub_col_data,
                                 kernel_size * hidden_dim,
                                 input_row_end - input_row_begin);
-        }
+#else
+          paddle::lite::arm::math::transpose(
+              tmp_data,
+              sub_col_data,
+              kernel_size * hidden_dim,
+              input_row_end - input_row_begin);
+#endif
+      }
     }
   }
   // SGDMM C := alpha * A * B + beta * C
-- 
GitLab