未验证 提交 bebe26e5 编写于 作者: H HappyAngel 提交者: GitHub

[arm]update sequence_pool and sequence_conv profiler (#3934)

* update sequence_pool and sequence_conv profiler, test=develop

* fix compute error, test=develop

* delete warning and extra info, test=develop
上级 611c9b37
......@@ -139,6 +139,86 @@ static bool conv_trans_weights_numc(const dtype* din,
}
return true;
}
template <typename Dtype>
void transpose(const Dtype* din, Dtype* dout, int m, int n) {
// nxm == mxn
// 4x4
int cnt_n = n >> 2;
int remain_n = n & 3;
int cnt_m = m >> 2;
int remain_m = m & 3;
int nn_num = n << 2; // n * 4
int mm_num = m << 2; // m * 4
for (int x = 0; x < cnt_n; x++) {
const Dtype* din_ptr0 = din + x * mm_num;
const Dtype* din_ptr1 = din_ptr0 + m;
const Dtype* din_ptr2 = din_ptr1 + m;
const Dtype* din_ptr3 = din_ptr2 + m;
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
// a00 b00 a02 b02 a01 b01 a03 b03
float32x4x2_t tmp0 = vtrnq_f32(din0, din1);
// c00 d00 c02 d02 c01 d01 c03 d03
float32x4x2_t tmp2 = vtrnq_f32(din2, din3);
din_ptr0 += 4;
din_ptr1 += 4;
// a00 b00 c00 d00 a02 b02 c02 d02
// a01 b01 c01 d01 a03 b03 c03 d03
float tmp_val1 = tmp0.val[0][2];
float tmp_val2 = tmp0.val[0][3];
tmp0.val[0][2] = tmp2.val[0][0];
tmp0.val[0][3] = tmp2.val[0][1];
float tmp_val3 = tmp0.val[1][2];
float tmp_val4 = tmp0.val[1][3];
tmp2.val[0][0] = tmp_val1;
tmp2.val[0][1] = tmp_val2;
tmp0.val[1][2] = tmp2.val[1][0];
tmp0.val[1][3] = tmp2.val[1][1];
tmp2.val[1][0] = tmp_val3;
tmp2.val[1][1] = tmp_val4;
din_ptr2 += 4;
din_ptr3 += 4;
vst1q_f32(dout_ptr0, tmp0.val[0]);
vst1q_f32(dout_ptr1, tmp0.val[1]);
dout_ptr0 += nn_num;
vst1q_f32(dout_ptr2, tmp2.val[0]);
vst1q_f32(dout_ptr3, tmp2.val[1]);
}
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
*dout_ptr0++ = *din_ptr1++;
*dout_ptr0++ = *din_ptr2++;
*dout_ptr0++ = *din_ptr3++;
}
}
const Dtype* din_ptr0 = din + cnt_n * mm_num;
dout = dout + cnt_n * 4;
for (int x = 0; x < remain_n; x++) {
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
din_ptr0 += 4;
*dout_ptr0 = din0[0];
*dout_ptr1 = din0[1];
dout_ptr0 += nn_num;
*dout_ptr2 = din0[2];
*dout_ptr3 = din0[3];
}
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
}
}
}
/*preprocessing inputs
* input din: [1, chin, he-hs, we - ws] --> outputs dout: [n, chin, 1, we - ws]
* n = he - hs
......
......@@ -2044,7 +2044,7 @@ void pooling3x3s1p0_avg(const float* din,
} else {
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom = 1) {
} else if (pad_bottom == 1) {
coef_h = 0.5f;
} else {
coef_h = 1.f;
......
......@@ -46,11 +46,60 @@ void seq_pool_sum<float>(const float* din,
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
height = height - 1;
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; ++w) {
dout_ptr[w] += din_ptr[w];
int cnt_w = width >> 2;
int remain_w = width & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vaddq_f32(din0, dout_val);
float32x4_t tmp = vaddq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vaddq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vaddq_f32(tmp, dout_val);
}
din_ptr += width;
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vaddq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
float tmp = din_ptr1[0] + din_ptr2[0];
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr += din_ptr3[0];
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr += tmp;
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr += din_ptr0[0];
din_ptr0 += width;
}
dout_ptr++;
}
}
}
......@@ -144,12 +193,62 @@ void seq_pool_max<float>(const float* din,
} else {
memcpy(dout_ptr, din_ptr, width * sizeof(float));
din_ptr += width;
int remain_h = height - 1;
for (int h = 0; h < remain_h; h++) {
for (int w = 0; w < width; w++) {
dout_ptr[w] = std::max(dout_ptr[w], din_ptr[w]);
height = height - 1;
int cnt_w = width >> 2;
int remain_w = width & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vmaxq_f32(din0, dout_val);
float32x4_t tmp = vmaxq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vmaxq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vmaxq_f32(tmp, dout_val);
}
din_ptr += width;
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vmaxq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
float tmp = std::max(din_ptr1[0], din_ptr2[0]);
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr = std::max(*dout_ptr, tmp);
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
din_ptr0 += width;
}
dout_ptr++;
}
}
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cstddef>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册