Commit 8a3fd00b authored by yongqiang

fix transpose error. test=develop

Parent 89d573e1
...
@@ -25,61 +25,174 @@ namespace lite {
 namespace kernels {
 namespace arm {
-bool IsShuffleChannel(const std::vector<int> &axis) {
-  bool is_shuffle_channel = true;
-  if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) {
-    for (int i = 3; i < axis.size(); ++i) {
-      if (axis[i] != i) {
-        is_shuffle_channel = false;
-        break;
-      }
-    }
-  } else {
-    return false;
-  }
-  return is_shuffle_channel;
-}
-template <typename Dtype>
-void ShuffleChannelCompute(const std::vector<int> &axis,
-                           const lite::Tensor *input,
-                           lite::Tensor *output) {
-  const Dtype *input_ptr = input->data<Dtype>();
-  Dtype *output_ptr = output->mutable_data<Dtype>();
-  // input and output's shape dimension must >= 2 && <= 6.
-  const DDim &in_dim = input->dims();
-  const DDim &out_dim = output->dims();
-  size_t offset = 1;
-  for (int i = 3; i < axis.size(); ++i) {
-    offset *= in_dim[i];
-  }
-#pragma omp parallel for collapse(3)
-  for (int batch = 0; batch < out_dim[0]; ++batch) {
-    for (int c1 = 0; c1 < out_dim[1]; ++c1) {
-      for (int c2 = 0; c2 < out_dim[2]; ++c2) {
-        size_t out_offset =
-            ((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset;
-        size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset;
-        memcpy(output_ptr + out_offset,
-               input_ptr + in_offset,
-               offset * sizeof(Dtype));
-      }
-    }
-  }
-}
+template <typename Dtype>
+void transpose_mat(const Dtype* din,
+                   Dtype* dout,
+                   const int num,
+                   const int width,
+                   const int height);
+void transpose_mat(const float* din,
+                   float* dout,
+                   const int num,
+                   const int width,
+                   const int height) {
+  int nw = width >> 2;
+  int nh = height >> 2;
+  int size_in = width * height;
+
+  for (int i = 0; i < num; ++i) {
+    float* ptr_out = dout + i * size_in;
+    const float* ptr_in = din + i * size_in;
+#pragma omp parallel for
+    for (int h = 0; h < nh; h++) {
+      const float* ptr_din_row = ptr_in + h * 4 * width;
+      for (int w = 0; w < nw; w++) {
+        float* data_out_ptr = ptr_out + w * 4 * height + h * 4;
+        const float* din0 = ptr_din_row;
+        const float* din1 = din0 + width;
+        const float* din2 = din1 + width;
+        const float* din3 = din2 + width;
+        float* dout0 = data_out_ptr;
+        float* dout1 = dout0 + height;
+        float* dout2 = dout1 + height;
+        float* dout3 = dout2 + height;
+#ifdef __aarch64__
+        float32x4_t vr0 = vld1q_f32(din0);
+        float32x4_t vr1 = vld1q_f32(din1);
+        float32x4_t vr2 = vld1q_f32(din2);
+        float32x4_t vr3 = vld1q_f32(din3);
+        float32x4_t re0 = vtrn1q_f32(vr0, vr1);
+        float32x4_t re1 = vtrn2q_f32(vr0, vr1);
+        float32x4_t re2 = vtrn1q_f32(vr2, vr3);
+        float32x4_t re3 = vtrn2q_f32(vr2, vr3);
+        vst1_f32(dout0, vget_low_f32(re0));
+        dout0 += 2;
+        vst1_f32(dout0, vget_low_f32(re2));
+        vst1_f32(dout1, vget_low_f32(re1));
+        dout1 += 2;
+        vst1_f32(dout1, vget_low_f32(re3));
+        vst1_f32(dout2, vget_high_f32(re0));
+        dout2 += 2;
+        vst1_f32(dout2, vget_high_f32(re2));
+        vst1_f32(dout3, vget_high_f32(re1));
+        dout3 += 2;
+        vst1_f32(dout3, vget_high_f32(re3));
+#else
+        asm("vld1.32 {d0, d1}, [%[in0]]    \n"
+            "vld1.32 {d2, d3}, [%[in1]]    \n"
+            "vld1.32 {d4, d5}, [%[in2]]    \n"
+            "vld1.32 {d6, d7}, [%[in3]]    \n"
+            "vtrn.32 q0, q1                \n"
+            "vtrn.32 q2, q3                \n"
+            "vswp d1, d4                   \n"
+            "vswp d3, d6                   \n"
+            "vst1.32 {d0, d1}, [%[out0]]   \n"
+            "vst1.32 {d2, d3}, [%[out1]]   \n"
+            "vst1.32 {d4, d5}, [%[out2]]   \n"
+            "vst1.32 {d6, d7}, [%[out3]]   \n"
+            :
+            : [out0] "r"(dout0),
+              [out1] "r"(dout1),
+              [out2] "r"(dout2),
+              [out3] "r"(dout3),
+              [in0] "r"(din0),
+              [in1] "r"(din1),
+              [in2] "r"(din2),
+              [in3] "r"(din3)
+            : "q0", "q1", "q2", "q3");
+#endif
+        ptr_din_row += 4;
+      }
+    }
+    // remain
+    for (int h = 0; h < height; h++) {
+      for (int w = nw * 4; w < width; w++) {
+        const float* data_in_ptr = ptr_in + h * width + w;
+        float* data_out_ptr = ptr_out + w * height + h;
+        *data_out_ptr = *data_in_ptr;
+      }
+    }
+    for (int w = 0; w < width; w++) {
+      for (int h = nh * 4; h < height; h++) {
+        const float* data_in_ptr = ptr_in + h * width + w;
+        float* data_out_ptr = ptr_out + w * height + h;
+        *data_out_ptr = *data_in_ptr;
+      }
+    }
+  }
+}
+
+std::vector<int> get_stride(const paddle::lite::DDimLite& dims) {
+  std::vector<int> data_stride{0};
+
+  for (int i = 0; i < dims.size(); ++i) {
+    data_stride.push_back(dims.count(i, dims.size()));
+  }
+  return data_stride;
+}
+
+void TransposeCompute::PrepareForRun() {
+  auto& param = Param<operators::TransposeParam>();
+  auto* input = param.x;
+  auto* output = param.output;
+
+  int _num_axes = input->dims().size();
+  CHECK(_num_axes == param.axis.size())
+      << "axis size does not match input dims";
+
+  need_trans = false;
+  for (int i = 0; i < _num_axes; ++i) {
+    if (param.axis[i] != i) {
+      need_trans = true;
+      break;
+    }
+  }
+  if (!need_trans) {
+    return;
+  }
+  std::vector<int> axis_diff;
+  int j = 0;
+  for (int i = 0; i < _num_axes; ++i) {
+    if (param.axis[j] != i) {
+      axis_diff.push_back(j);
+    } else {
+      j++;
+    }
+  }
+  for (int i = 0; i < axis_diff.size(); i++) {
+  }
+  if (input->dims().count(axis_diff[0], _num_axes) == 1) {
+    need_trans = false;
+    return;
+  }
+  if (axis_diff.size() == 1) {
+    trans_mat = true;
+    _trans_num = input->dims().count(0, std::max(axis_diff[0], 0));
+    _trans_w = input->dims().count(axis_diff[0] + 1, _num_axes);
+    _trans_h = input->dims()[axis_diff[0]];
+  } else {
+    trans_mat = false;
+    _new_steps = get_stride(output->dims());
+    _old_steps = get_stride(input->dims());
+  }
+}
+
 template <typename Dtype>
-void TransposeCompute_(const std::vector<int> &axis,
-                       const lite::Tensor *input,
-                       lite::Tensor *output) {
+void TransposeCompute_(const std::vector<int>& axis,
+                       const lite::Tensor* input,
+                       lite::Tensor* output) {
   // const Dtype *input_ptr = input->data<Dtype>();
-  const Dtype *input_ptr = input->data<float>();
-  Dtype *output_ptr = output->mutable_data<Dtype>();
+  const Dtype* input_ptr = input->data<float>();
+  Dtype* output_ptr = output->mutable_data<Dtype>();
   // input and output's shape dimension must >= 2 && <= 6.
-  const DDim &in_dim = input->dims();
-  const DDim &out_dim = output->dims();
+  const DDim& in_dim = input->dims();
+  const DDim& out_dim = output->dims();
   // precompute inverted output dim and strides
   size_t rout_dim[6], strides[6];
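
Both branches of the #ifdef above implement the same 4x4 in-register tile transpose: four consecutive input rows are loaded, interleaved with vtrn1q/vtrn2q (or vtrn.32 plus vswp on armv7), and stored as four output rows, so each input column becomes an output row. A minimal scalar sketch of that tile operation (an illustration with made-up names, not code from this patch):

#include <cstdio>

// Scalar reference for one 4x4 tile: dst is the transpose of src.
// The NEON paths above compute exactly this, four lanes at a time.
static void transpose4x4_ref(const float src[4][4], float dst[4][4]) {
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) {
      dst[c][r] = src[r][c];
    }
  }
}

int main() {
  float src[4][4], dst[4][4];
  for (int r = 0; r < 4; ++r)
    for (int c = 0; c < 4; ++c) src[r][c] = r * 4.0f + c;
  transpose4x4_ref(src, dst);
  printf("%.0f %.0f\n", dst[1][0], dst[0][1]);  // prints "1 4"
  return 0;
}

Widths and heights that are not multiples of 4 are covered by the scalar "remain" loops in transpose_mat itself.
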
...
@@ -103,7 +216,7 @@ void TransposeCompute_(const std::vector<int> &axis,
   for (int batch = 0; batch < out_dim[0]; ++batch) {
     for (int j = 0; j < out_dim[1]; ++j) {
       size_t offset = batch * strides[permute - 1] + j * strides[permute - 2];
-      Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim;
+      Dtype* out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim;
       int indics[4] = {0, 0, 0, 0};
       for (int k = 0; k < reamin_dim; ++k) {
         out_ptr[k] = input_ptr[offset];
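
These hunks show only fragments of TransposeCompute_; the underlying scheme is a generic stride-based permutation: each output element's linear index is decomposed into output coordinates, which are mapped through axis back to an input offset. A self-contained sketch of that scheme under assumed names (transpose_ref is illustrative, not the kernel's actual body):

#include <cstdio>
#include <vector>

// For every output element, decompose its linear index into output
// coordinates, then read the input element at the same coordinates
// permuted through `axis`.
void transpose_ref(const std::vector<int>& axis,
                   const std::vector<int>& in_shape,
                   const float* in, float* out) {
  const int rank = static_cast<int>(axis.size());
  std::vector<int> out_shape(rank), in_stride(rank), out_stride(rank);
  for (int i = 0; i < rank; ++i) out_shape[i] = in_shape[axis[i]];
  in_stride[rank - 1] = out_stride[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    in_stride[i] = in_stride[i + 1] * in_shape[i + 1];
    out_stride[i] = out_stride[i + 1] * out_shape[i + 1];
  }
  int total = 1;
  for (int d : in_shape) total *= d;
  for (int o = 0; o < total; ++o) {
    int rem = o, src = 0;
    for (int i = 0; i < rank; ++i) {
      int coord = rem / out_stride[i];    // output coordinate along dim i
      rem %= out_stride[i];
      src += coord * in_stride[axis[i]];  // same coordinate in the input
    }
    out[o] = in[src];
  }
}

int main() {
  // 2x3 -> 3x2 with axis = {1, 0}.
  const float in[6] = {0, 1, 2, 3, 4, 5};
  float out[6];
  transpose_ref({1, 0}, {2, 3}, in, out);
  printf("%.0f %.0f\n", out[1], out[2]);  // prints "3 1"
  return 0;
}
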
...
@@ -123,37 +236,27 @@ void TransposeCompute_(const std::vector<int> &axis,
       }
     }
   }
 }
 // Transpose
 void TransposeCompute::Run() {
-  auto &param = Param<operators::TransposeParam>();
-  auto *input = param.x;
-  auto *output = param.output;
+  auto& param = Param<operators::TransposeParam>();
+  auto* input = param.x;
+  auto* output = param.output;
   const std::vector<int> axis = param.axis;
-  bool shuffle_channel = IsShuffleChannel(axis);
-  if (shuffle_channel) {
-    ShuffleChannelCompute<float>(axis, input, output);
-  } else {
-    TransposeCompute_<float>(axis, input, output);
-  }
+  //! only copy the data
+  if (!need_trans) {
+    output->CopyDataFrom(*input);
     return;
   }
-// Transpose2
-void Transpose2Compute::Run() {
-  auto &param = Param<operators::TransposeParam>();
-  auto *input = param.x;
-  auto *output = param.output;
-  const std::vector<int> axis = param.axis;
-  bool shuffle_channel = IsShuffleChannel(axis);
-  if (shuffle_channel) {
-    ShuffleChannelCompute<float>(axis, input, output);
+  const float* din = static_cast<const float*>(input->data<float>());
+  float* dout = static_cast<float*>(output->mutable_data<float>());
+  //! transpose the data
+  if (trans_mat) {
+    transpose_mat(din, dout, _trans_num, _trans_w, _trans_h);
   } else {
     TransposeCompute_<float>(axis, input, output);
   }
-  return;
 }
 }  // namespace arm
...
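
Taken together, the new PrepareForRun/Run pair dispatches three ways: an identity permutation makes Run() a plain CopyDataFrom; a permutation whose axis_diff scan records exactly one entry (e.g. axis = {0, 2, 1} on an N×H×W input, giving _trans_num = N, _trans_h = H, _trans_w = W) takes the batched transpose_mat fast path; anything else falls back to the stride-based TransposeCompute_. A small sketch of that scan (axis_diff_scan is an illustrative name, not part of the patch):

#include <cstdio>
#include <vector>

// Replicates the axis_diff scan from PrepareForRun: an entry is
// recorded wherever the permutation deviates from the identity.
std::vector<int> axis_diff_scan(const std::vector<int>& axis) {
  std::vector<int> axis_diff;
  int j = 0;
  for (int i = 0; i < static_cast<int>(axis.size()); ++i) {
    if (axis[j] != i) {
      axis_diff.push_back(j);
    } else {
      j++;
    }
  }
  return axis_diff;
}

int main() {
  // {0, 2, 1}: one mismatch -> trans_mat fast path (batched 2-D transpose).
  printf("%zu\n", axis_diff_scan({0, 2, 1}).size());  // prints "1"
  // {2, 0, 1}: more than one -> generic stride-based fallback.
  printf("%zu\n", axis_diff_scan({2, 0, 1}).size());  // prints "2"
  return 0;
}
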