/// Generic declaration of the batched 2-D transpose. Only the float overload
/// below is defined in this part of the file.
template <typename Dtype>
void transpose_mat(const Dtype* din,
                   Dtype* dout,
                   const int num,
                   const int width,
                   const int height);

/**
 * @brief Transposes `num` consecutive row-major float matrices.
 *
 * Each input matrix has `height` rows of `width` elements; element (h, w) is
 * read from `h * width + w` and written to `w * height + h`, so each output
 * matrix has `width` rows of `height` elements. Matrix `i` starts at offset
 * `i * width * height` in both buffers.
 *
 * Full 4x4 tiles are transposed with NEON on AArch64, with inline assembly on
 * 32-bit ARM, and with a scalar tile loop elsewhere. Rows and columns that do
 * not fill a 4x4 tile are copied element by element.
 *
 * @param din    input buffer holding `num * width * height` floats
 * @param dout   output buffer of the same size; must not overlap `din`
 *               (the tile code reads and writes through independent pointers)
 * @param num    number of matrices
 * @param width  number of columns of each input matrix
 * @param height number of rows of each input matrix
 */
void transpose_mat(const float* din,
                   float* dout,
                   const int num,
                   const int width,
                   const int height) {
  const int nw = width >> 2;   // number of full 4-column tiles
  const int nh = height >> 2;  // number of full 4-row tiles
  const int size_in = width * height;

  for (int i = 0; i < num; ++i) {
    float* ptr_out = dout + i * size_in;
    const float* ptr_in = din + i * size_in;
#pragma omp parallel for
    for (int h = 0; h < nh; h++) {
      const float* ptr_din_row = ptr_in + h * 4 * width;
      for (int w = 0; w < nw; w++) {
        // Top-left corner of the destination tile: row w*4, column h*4.
        float* data_out_ptr = ptr_out + w * 4 * height + h * 4;
        const float* din0 = ptr_din_row;
        const float* din1 = din0 + width;
        const float* din2 = din1 + width;
        const float* din3 = din2 + width;

        float* dout0 = data_out_ptr;
        float* dout1 = dout0 + height;
        float* dout2 = dout1 + height;
        float* dout3 = dout2 + height;
#if defined(__aarch64__)
        const float32x4_t vr0 = vld1q_f32(din0);
        const float32x4_t vr1 = vld1q_f32(din1);
        const float32x4_t vr2 = vld1q_f32(din2);
        const float32x4_t vr3 = vld1q_f32(din3);
        // re0 = {r0[0], r1[0], r0[2], r1[2]}, re1 = {r0[1], r1[1], r0[3], r1[3]}
        const float32x4_t re0 = vtrn1q_f32(vr0, vr1);
        const float32x4_t re1 = vtrn2q_f32(vr0, vr1);
        // Same interleave for input rows 2 and 3.
        const float32x4_t re2 = vtrn1q_f32(vr2, vr3);
        const float32x4_t re3 = vtrn2q_f32(vr2, vr3);
        // Each output row is two 64-bit halves: rows (0,1) then rows (2,3).
        vst1_f32(dout0, vget_low_f32(re0));
        vst1_f32(dout0 + 2, vget_low_f32(re2));
        vst1_f32(dout1, vget_low_f32(re1));
        vst1_f32(dout1 + 2, vget_low_f32(re3));
        vst1_f32(dout2, vget_high_f32(re0));
        vst1_f32(dout2 + 2, vget_high_f32(re2));
        vst1_f32(dout3, vget_high_f32(re1));
        vst1_f32(dout3 + 2, vget_high_f32(re3));
#elif defined(__arm__)
        // The asm writes memory through pointers that are passed as *input*
        // operands, so the "memory" clobber is required to tell the compiler
        // that the destination buffer changes and the source must be up to
        // date before the statement executes.
        asm volatile(
            "vld1.32 {d0, d1}, [%[in0]]    \n"
            "vld1.32 {d2, d3}, [%[in1]]    \n"
            "vld1.32 {d4, d5}, [%[in2]]    \n"
            "vld1.32 {d6, d7}, [%[in3]]    \n"
            "vtrn.32 q0, q1                \n"
            "vtrn.32 q2, q3                \n"
            "vswp d1, d4                   \n"
            "vswp d3, d6                   \n"
            "vst1.32 {d0, d1}, [%[out0]]   \n"
            "vst1.32 {d2, d3}, [%[out1]]   \n"
            "vst1.32 {d4, d5}, [%[out2]]   \n"
            "vst1.32 {d6, d7}, [%[out3]]   \n"
            :
            : [out0] "r"(dout0),
              [out1] "r"(dout1),
              [out2] "r"(dout2),
              [out3] "r"(dout3),
              [in0] "r"(din0),
              [in1] "r"(din1),
              [in2] "r"(din2),
              [in3] "r"(din3)
            : "q0", "q1", "q2", "q3", "memory");
#else
        // Portable fallback for non-ARM builds: plain 4x4 tile transpose.
        const float* tile_in[4] = {din0, din1, din2, din3};
        float* tile_out[4] = {dout0, dout1, dout2, dout3};
        for (int r = 0; r < 4; ++r) {
          for (int c = 0; c < 4; ++c) {
            tile_out[c][r] = tile_in[r][c];
          }
        }
#endif
        ptr_din_row += 4;
      }
    }
    // remain: columns to the right of the last full tile, for every row
    for (int h = 0; h < height; h++) {
      for (int w = nw * 4; w < width; w++) {
        ptr_out[w * height + h] = ptr_in[h * width + w];
      }
    }
    // remain: rows below the last full tile; the columns >= nw * 4 were
    // already written by the loop above, so only the tiled columns are left
    for (int w = 0; w < nw * 4; w++) {
      for (int h = nh * 4; h < height; h++) {
        ptr_out[w * height + h] = ptr_in[h * width + w];
      }
    }
  }
}
- const DDim &in_dim = input->dims(); - const DDim &out_dim = output->dims(); - size_t offset = 1; - for (int i = 3; i < axis.size(); ++i) { - offset *= in_dim[i]; +std::vector get_stride(const paddle::lite::DDimLite& dims) { + std::vector data_stride{0}; + + for (int i = 0; i < dims.size(); ++i) { + data_stride.push_back(dims.count(i, dims.size())); } + return data_stride; +} -#pragma omp parallel for collapse(3) - for (int batch = 0; batch < out_dim[0]; ++batch) { - for (int c1 = 0; c1 < out_dim[1]; ++c1) { - for (int c2 = 0; c2 < out_dim[2]; ++c2) { - size_t out_offset = - ((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset; - size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset; - memcpy(output_ptr + out_offset, - input_ptr + in_offset, - offset * sizeof(Dtype)); - } +void TransposeCompute::PrepareForRun() { + auto& param = Param(); + auto* input = param.x; + auto* output = param.output; + + int _num_axes = input->dims().size(); + CHECK(_num_axes == param.axis.size()) + << "axis size is not match to input dims"; + + need_trans = false; + for (int i = 0; i < _num_axes; ++i) { + if (param.axis[i] != i) { + need_trans = true; + break; + } + } + + if (!need_trans) { + return; + } + + std::vector axis_diff; + int j = 0; + for (int i = 0; i < _num_axes; ++i) { + if (param.axis[j] != i) { + axis_diff.push_back(j); + } else { + j++; } } + for (int i = 0; i < axis_diff.size(); i++) { + } + if (input->dims().count(axis_diff[0], _num_axes) == 1) { + need_trans = false; + return; + } + + if (axis_diff.size() == 1) { + trans_mat = true; + _trans_num = input->dims().count(0, std::max(axis_diff[0], 0)); + _trans_w = input->dims().count(axis_diff[0] + 1, _num_axes); + _trans_h = input->dims()[axis_diff[0]]; + + } else { + trans_mat = false; + _new_steps = get_stride(output->dims()); + _old_steps = get_stride(input->dims()); + } } template -void TransposeCompute_(const std::vector &axis, - const lite::Tensor *input, - lite::Tensor *output) { +void 
TransposeCompute_(const std::vector& axis, + const lite::Tensor* input, + lite::Tensor* output) { // const Dtype *input_ptr = input->data(); - const Dtype *input_ptr = input->data(); - Dtype *output_ptr = output->mutable_data(); + const Dtype* input_ptr = input->data(); + Dtype* output_ptr = output->mutable_data(); // input and output's shape dimension must >= 2 && <= 6. - const DDim &in_dim = input->dims(); - const DDim &out_dim = output->dims(); + const DDim& in_dim = input->dims(); + const DDim& out_dim = output->dims(); // precompute inverted output dim and strides size_t rout_dim[6], strides[6]; @@ -103,7 +216,7 @@ void TransposeCompute_(const std::vector &axis, for (int batch = 0; batch < out_dim[0]; ++batch) { for (int j = 0; j < out_dim[1]; ++j) { size_t offset = batch * strides[permute - 1] + j * strides[permute - 2]; - Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim; + Dtype* out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim; int indics[4] = {0, 0, 0, 0}; for (int k = 0; k < reamin_dim; ++k) { out_ptr[k] = input_ptr[offset]; @@ -123,37 +236,27 @@ void TransposeCompute_(const std::vector &axis, } } } - // Transpose void TransposeCompute::Run() { - auto ¶m = Param(); - auto *input = param.x; - auto *output = param.output; + auto& param = Param(); + auto* input = param.x; + auto* output = param.output; const std::vector axis = param.axis; - bool shuffle_channel = IsShuffleChannel(axis); - if (shuffle_channel) { - ShuffleChannelCompute(axis, input, output); - } else { - TransposeCompute_(axis, input, output); + //! 
only copy the data + if (!need_trans) { + output->CopyDataFrom(*input); + return; } - return; -} - -// Transpose2 -void Transpose2Compute::Run() { - auto ¶m = Param(); - auto *input = param.x; - auto *output = param.output; - const std::vector axis = param.axis; - bool shuffle_channel = IsShuffleChannel(axis); - if (shuffle_channel) { - ShuffleChannelCompute(axis, input, output); + const float* din = static_cast(input->data()); + float* dout = static_cast(output->mutable_data()); + //! transpose the data + if (trans_mat) { + transpose_mat(din, dout, _trans_num, _trans_w, _trans_h); } else { TransposeCompute_(axis, input, output); } - return; } } // namespace arm