提交 4403de09 编写于 作者: C chenjiaoAngel

fix format, testt=develop

上级 13a95dc6
......@@ -147,8 +147,8 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
int remain_n = n & 3;
int cnt_m = m >> 2;
int remain_m = m & 3;
int nn_num = n << 2; // n * 4
int mm_num = m << 2; // m * 4
int nn_num = n << 2; // n * 4
int mm_num = m << 2; // m * 4
for (int x = 0; x < cnt_n; x++) {
const Dtype* din_ptr0 = din + x * mm_num;
const Dtype* din_ptr1 = din_ptr0 + m;
......@@ -156,7 +156,7 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
const Dtype* din_ptr3 = din_ptr2 + m;
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
......@@ -164,16 +164,16 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
float32x4x2_t tmp0 = vtrnq_f32(din0, din1); // a00 b00 a02 b02
//float32x4_t tmp1 = vtrn2q_f32(din0, din1); // a01 b01 a03 b03
float32x4x2_t tmp2 = vtrnq_f32(din2, din3); // c00 d00 c02 d02
// float32x4_t tmp3 = vtrn2q_f32(din2, din3); // c01 d01 c03 d03
// a00 b00 a02 b02 a01 b01 a03 b03
float32x4x2_t tmp0 = vtrnq_f32(din0, din1);
// c00 d00 c02 d02 c01 d01 c03 d03
float32x4x2_t tmp2 = vtrnq_f32(din2, din3);
din_ptr0 += 4;
din_ptr1 += 4;
float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]); // a00 b00 c00 d00
// float32x4_t tmp01 = vtrn2q_f32(tmp0, tmp2); // a02 b02 c02 d02
float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]); // a01 b01 c01 d01
// float32x4_t tmp03 = vtrn2q_f32(tmp1, tmp3); // a03 b03 c03 d03
// a00 b00 c00 d00 a02 b02 c02 d02
float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]);
// a01 b01 c01 d01 a03 b03 c03 d03
float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]);
din_ptr2 += 4;
din_ptr3 += 4;
vst1q_f32(dout_ptr0, tmp00.val[0]);
......@@ -193,7 +193,7 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
for (int x = 0; x < remain_n; x++) {
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0); // a00 a01 a02 a03
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_ptr0 += nn_num;
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
......
......@@ -54,7 +54,7 @@ void seq_pool_sum<float>(const float* din,
din_ptr += width;
}
#else
int cnt_w = height >> 2;
int cnt_w = height >> 2;
int remain_w = height & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
......@@ -211,62 +211,62 @@ void seq_pool_max<float>(const float* din,
din_ptr += width;
}
#else
int cnt_w = height >> 2;
int remain_w = height & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vmaxq_f32(din0, dout_val);
float32x4_t tmp = vmaxq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vmaxq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vmaxq_f32(tmp, dout_val);
}
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vmaxq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
float tmp = std::max(din_ptr1[0], din_ptr2[0]);
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr = std::max(*dout_ptr, tmp);
int cnt_w = height >> 2;
int remain_w = height & 3;
int cnt_h = height >> 2;
int remain_h = height & 3;
int stride = width << 2;
for (int w = 0; w < cnt_w; w++) {
const float* din_ptr0 = din_ptr + w * 4;
float32x4_t dout_val = vld1q_f32(dout_ptr);
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_val = vmaxq_f32(din0, dout_val);
float32x4_t tmp = vmaxq_f32(din1, din2);
din_ptr0 += stride;
din_ptr1 += stride;
dout_val = vmaxq_f32(din3, dout_val);
din_ptr2 += stride;
din_ptr3 += stride;
dout_val = vmaxq_f32(tmp, dout_val);
}
for (int h = 0; h < remain_h; h++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_val = vmaxq_f32(din0, dout_val);
din_ptr0 += width;
}
vst1q_f32(dout_ptr, dout_val);
dout_ptr += 4;
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
din_ptr0 += width;
const float* din_ptr00 = din_ptr + cnt_w * 4;
for (int w = 0; w < remain_w; w++) {
const float* din_ptr0 = din_ptr00 + w;
const float* din_ptr1 = din_ptr0 + width;
const float* din_ptr2 = din_ptr1 + width;
const float* din_ptr3 = din_ptr2 + width;
for (int h = 0; h < cnt_h; h++) {
*dout_ptr += din_ptr0[0];
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
float tmp = std::max(din_ptr1[0], din_ptr2[0]);
din_ptr0 += stride;
din_ptr1 += stride;
*dout_ptr = std::max(*dout_ptr, din_ptr3[0]);
din_ptr2 += stride;
din_ptr3 += stride;
*dout_ptr = std::max(*dout_ptr, tmp);
}
for (int h = 0; h < remain_h; h++) {
*dout_ptr = std::max(*dout_ptr, din_ptr0[0]);
din_ptr0 += width;
}
dout_ptr++;
}
dout_ptr++;
}
#endif
}
}
......
......@@ -17,8 +17,8 @@ limitations under the License. */
#include <cstddef>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
......@@ -108,13 +108,12 @@ void SequenceConvCompute::Run() {
kernel_size * hidden_dim,
input_row_end - input_row_begin);
#else
paddle::lite::arm::math::transpose(
tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
#endif
}
paddle::lite::arm::math::transpose(tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
#endif
}
}
// SGDMM C := alpha * A * B + beta * C
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册