提交 13a95dc6 编写于 作者: C chenjiaoAngel

uupdate sequence_pool and sequence_conv profiler, test=develop

上级 aa418461
......@@ -139,6 +139,77 @@ static bool conv_trans_weights_numc(const dtype* din,
}
return true;
}
template <typename Dtype>
void transpose(const Dtype* din, Dtype* dout, int m, int n) {
// nxm == mxn
// 4x4 分块处理
int cnt_n = n >> 2;
int remain_n = n & 3;
int cnt_m = m >> 2;
int remain_m = m & 3;
int nn_num = n << 2; // n * 4
int mm_num = m << 2; // m * 4
for (int x = 0; x < cnt_n; x++) {
const Dtype* din_ptr0 = din + x * mm_num;
const Dtype* din_ptr1 = din_ptr0 + m;
const Dtype* din_ptr2 = din_ptr1 + m;
const Dtype* din_ptr3 = din_ptr2 + m;
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_ptr0 += nn_num;
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
float32x4x2_t tmp0 = vtrnq_f32(din0, din1); // a00 b00 a02 b02
//float32x4_t tmp1 = vtrn2q_f32(din0, din1); // a01 b01 a03 b03
float32x4x2_t tmp2 = vtrnq_f32(din2, din3); // c00 d00 c02 d02
// float32x4_t tmp3 = vtrn2q_f32(din2, din3); // c01 d01 c03 d03
din_ptr0 += 4;
din_ptr1 += 4;
float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]); // a00 b00 c00 d00
// float32x4_t tmp01 = vtrn2q_f32(tmp0, tmp2); // a02 b02 c02 d02
float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]); // a01 b01 c01 d01
// float32x4_t tmp03 = vtrn2q_f32(tmp1, tmp3); // a03 b03 c03 d03
din_ptr2 += 4;
din_ptr3 += 4;
vst1q_f32(dout_ptr0, tmp00.val[0]);
vst1q_f32(dout_ptr1, tmp02.val[0]);
vst1q_f32(dout_ptr2, tmp00.val[1]);
vst1q_f32(dout_ptr3, tmp02.val[1]);
}
dout_ptr0 += nn_num;
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
*dout_ptr0++ = *din_ptr1++;
*dout_ptr0++ = *din_ptr2++;
*dout_ptr0++ = *din_ptr3++;
}
}
const Dtype* din_ptr0 = din + cnt_n * mm_num;
for (int x = 0; x < remain_n; x++) {
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
dout_ptr0 += nn_num;
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
din_ptr0 += 4;
*dout_ptr0 = din0[0];
*dout_ptr1 = din0[1];
*dout_ptr2 = din0[2];
*dout_ptr3 = din0[3];
}
dout_ptr0 += nn_num;
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
}
}
}
/*preprocessing inputs
* input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws]
* n = he - hs
......
......@@ -46,12 +46,70 @@ void seq_pool_sum<float>(const float* din,
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
height = height - 1;
#if 0
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; ++w) {
dout_ptr[w] += din_ptr[w];
}
din_ptr += width;
}
#else
int cnt_w = height >> 2;
int remain_w = height & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vaddq_f32(din0, dout_val);
float32x4_t tmp = vaddq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vaddq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vaddq_f32(tmp, dout_val);
}
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vaddq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
float tmp = din_ptr1[0] + din_ptr2[0];
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr += din_ptr3[0];
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr += tmp;
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr += din_ptr0[0];
din_ptr0 += width;
}
dout_ptr++;
}
#endif
}
}
}
......@@ -144,13 +202,72 @@ void seq_pool_max<float>(const float* din,
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
int remain_h = height - 1;
for (int h = 0; h < remain_h; h++) {
height = height - 1;
#if 0
for (int h = 0; h < rheight; h++) {
for (int w = 0; w < width; w++) {
dout_ptr[w] = std::max(dout_ptr[w], din_ptr[w]);
}
din_ptr += width;
}
#else
int cnt_w = height >> 2;
int remain_w = height & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vmaxq_f32(din0, dout_val);
float32x4_t tmp = vmaxq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vmaxq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vmaxq_f32(tmp, dout_val);
}
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vmaxq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
float tmp = std::max(din_ptr1[0], din_ptr2[0]);
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr = std::max(*dout_ptr, tmp);
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
din_ptr0 += width;
}
dout_ptr++;
}
#endif
}
}
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
......@@ -101,11 +102,19 @@ void SequenceConvCompute::Run() {
1,
1, // stride_h, stride_w, dilation_h, dilation_w
tmp_data);
#if 0
local_naive_transpose(tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
}
#else
paddle::lite::arm::math::transpose(
tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
#endif
}
}
// SGDMM C := alpha * A * B + beta * C
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册