From 819ccb8ada1b398a224a74e010064b0cbe0e0690 Mon Sep 17 00:00:00 2001 From: yongqiang Date: Tue, 28 Jul 2020 15:02:20 +0000 Subject: [PATCH] fix transpose error. test=develop --- lite/kernels/arm/transpose_compute.cc | 83 ++++++++++++++++++--------- 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/lite/kernels/arm/transpose_compute.cc b/lite/kernels/arm/transpose_compute.cc index 5e139a2693..c445df453b 100644 --- a/lite/kernels/arm/transpose_compute.cc +++ b/lite/kernels/arm/transpose_compute.cc @@ -25,26 +25,6 @@ namespace lite { namespace kernels { namespace arm { -template -void trans_basic(const int count, - const Dtype* din, - const int* permute_order, - const int* old_steps, - const int* new_steps, - const int num_axes, - Dtype* dout) { - for (int i = 0; i < count; ++i) { - int old_idx = 0; - int idx = i; - for (int j = 0; j < num_axes; ++j) { - int order = permute_order[j]; - old_idx += (idx / new_steps[j]) * old_steps[order]; - idx %= new_steps[j]; - } - dout[i] = din[old_idx]; - } -} - template void transpose_mat(const Dtype* din, Dtype* dout, @@ -201,6 +181,61 @@ void TransposeCompute::PrepareForRun() { _old_steps = get_stride(input->dims()); } } + +template +void TransposeCompute_(const std::vector& axis, + const lite::Tensor* input, + lite::Tensor* output) { + // const Dtype *input_ptr = input->data(); + const Dtype* input_ptr = input->data(); + Dtype* output_ptr = output->mutable_data(); + + // input and output's shape dimension must >= 2 && <= 6. + const DDim& in_dim = input->dims(); + const DDim& out_dim = output->dims(); + + // precompute inverted output dim and strides + size_t rout_dim[6], strides[6]; + int permute = axis.size(); // permute must >=2 && <= 6. + for (int i = 0; i < permute; ++i) { + int k = permute - 1 - i; + strides[k] = 1; + for (int j = axis[i] + 1; j < permute; ++j) { + strides[k] *= in_dim[j]; + } + rout_dim[k] = out_dim[i]; + } + + // unroll the first 2 dimensions + int reamin_dim = 1; + for (int i = 2; i < out_dim.size(); ++i) { + reamin_dim *= out_dim[i]; + } + +#pragma omp parallel for collapse(2) + for (int batch = 0; batch < out_dim[0]; ++batch) { + for (int j = 0; j < out_dim[1]; ++j) { + size_t offset = batch * strides[permute - 1] + j * strides[permute - 2]; + Dtype* out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim; + int indics[4] = {0, 0, 0, 0}; + for (int k = 0; k < reamin_dim; ++k) { + out_ptr[k] = input_ptr[offset]; + indics[0] += 1; + offset += strides[0]; + for (int p = 0; p < permute - 3; ++p) { + if (indics[p] == rout_dim[p]) { + indics[p + 1] += 1; + indics[p] = 0; + offset += strides[p + 1]; + offset -= rout_dim[p] * strides[p]; + } else { + break; + } + } + } + } + } +} // Transpose void TransposeCompute::Run() { auto& param = Param(); @@ -220,13 +255,7 @@ void TransposeCompute::Run() { if (trans_mat) { transpose_mat(din, dout, _trans_num, _trans_w, _trans_h); } else { - trans_basic(output->numel(), - din, - param.axis.data(), - _old_steps.data(), - _new_steps.data(), - input->dims().size(), - dout); + TransposeCompute_(axis, input, output); } } -- GitLab