提交 15c7f5b2 编写于 作者: C chenjiaoAngel

fix compute error

上级 1207be18
......@@ -160,7 +160,6 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
float32x4_t din1 = vld1q_f32(din_ptr1);
float32x4_t din2 = vld1q_f32(din_ptr2);
float32x4_t din3 = vld1q_f32(din_ptr3);
dout_ptr0 += nn_num;
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
......@@ -171,17 +170,27 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
din_ptr0 += 4;
din_ptr1 += 4;
// a00 b00 c00 d00 a02 b02 c02 d02
float32x4x2_t tmp00 = vtrnq_f32(tmp0.val[0], tmp2.val[0]);
// a01 b01 c01 d01 a03 b03 c03 d03
float32x4x2_t tmp02 = vtrnq_f32(tmp0.val[1], tmp2.val[1]);
float tmp_val1 = tmp0.val[0][2];
float tmp_val2 = tmp0.val[0][3];
tmp0.val[0][2] = tmp2.val[0][0];
tmp0.val[0][3] = tmp2.val[0][1];
float tmp_val3 = tmp0.val[1][2];
float tmp_val4 = tmp0.val[1][3];
tmp2.val[0][0] = tmp_val1;
tmp2.val[0][1] = tmp_val2;
tmp0.val[1][2] = tmp2.val[1][0];
tmp0.val[1][3] = tmp2.val[1][1];
tmp2.val[1][0] = tmp_val3;
tmp2.val[1][1] = tmp_val4;
din_ptr2 += 4;
din_ptr3 += 4;
vst1q_f32(dout_ptr0, tmp00.val[0]);
vst1q_f32(dout_ptr1, tmp02.val[0]);
vst1q_f32(dout_ptr2, tmp00.val[1]);
vst1q_f32(dout_ptr3, tmp02.val[1]);
vst1q_f32(dout_ptr0, tmp0.val[0]);
vst1q_f32(dout_ptr1, tmp0.val[1]);
dout_ptr0 += nn_num;
vst1q_f32(dout_ptr2, tmp2.val[0]);
vst1q_f32(dout_ptr3, tmp2.val[1]);
}
dout_ptr0 += nn_num;
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
*dout_ptr0++ = *din_ptr1++;
......@@ -190,21 +199,21 @@ void transpose(const Dtype* din, Dtype* dout, int m, int n) {
}
}
const Dtype* din_ptr0 = din + cnt_n * mm_num;
dout = dout + cnt_n * 4;
for (int x = 0; x < remain_n; x++) {
Dtype* dout_ptr0 = dout + x * 4;
for (int y = 0; y < cnt_m; y++) {
float32x4_t din0 = vld1q_f32(din_ptr0);
dout_ptr0 += nn_num;
Dtype* dout_ptr1 = dout_ptr0 + n;
Dtype* dout_ptr2 = dout_ptr1 + n;
Dtype* dout_ptr3 = dout_ptr2 + n;
din_ptr0 += 4;
*dout_ptr0 = din0[0];
*dout_ptr1 = din0[1];
dout_ptr0 += nn_num;
*dout_ptr2 = din0[2];
*dout_ptr3 = din0[3];
}
dout_ptr0 += nn_num;
for (int y = 0; y < remain_m; y++) {
*dout_ptr0++ = *din_ptr0++;
}
......
......@@ -102,17 +102,10 @@ void SequenceConvCompute::Run() {
1,
1, // stride_h, stride_w, dilation_h, dilation_w
tmp_data);
#if 0
local_naive_transpose(tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
#else
paddle::lite::arm::math::transpose(tmp_data,
sub_col_data,
kernel_size * hidden_dim,
input_row_end - input_row_begin);
#endif
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册