diff --git a/src/operators/kernel/arm/transpose_kernel.cpp b/src/operators/kernel/arm/transpose_kernel.cpp index 3ebe261fb8fe511022d6efbf4641898ef326319f..1b41968f40d036d55b98298a76564dcc12576571 100644 --- a/src/operators/kernel/arm/transpose_kernel.cpp +++ b/src/operators/kernel/arm/transpose_kernel.cpp @@ -11,29 +11,28 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #ifdef TRANSPOSE_OP #include "operators/kernel/transpose_kernel.h" - namespace paddle_mobile { namespace operators { -template -void TransposeFunc(const int numel, const T* input, const vector axis, - const vector old_strides, const vector new_strides, - T* output) { - for (int i = 0; i < numel; ++i) { - int old_idx = 0; - int idx = i; - for (int j = 0; j < axis.size(); ++j) { - int order = axis[j]; - old_idx += (idx / new_strides[j]) * old_strides[order]; - idx %= new_strides[j]; - } - output[i] = input[old_idx]; - } -} +// vector pos; +// template +// void TransposeFunc(const int numel, const T* input, const vector axis, +// const vector old_strides, const vector +// new_strides, T* output) { +// for (int i = 0; i < numel; ++i) { +// int old_idx = 0; +// int idx = i; +// for (int j = 0; j < axis.size(); ++j) { +// int order = axis[j]; +// old_idx += (idx / new_strides[j]) * old_strides[order]; +// idx %= new_strides[j]; +// } +// output[i] = input[old_idx]; +// } +// } template <> void TransposeKernel::Compute(const TransposeParam& param) const { @@ -44,28 +43,38 @@ void TransposeKernel::Compute(const TransposeParam& param) const { const auto* input_x_data = input_x->data(); auto* out_data = out->mutable_data(); - size_t axis_size = axis.size(); - std::vector new_dims; - new_dims.reserve(axis_size); - for (auto c : axis) { - new_dims.push_back(input_x_dims[c]); + size_t ndim = axis.size(); + std::vector xdim(ndim); + std::vector xstride(ndim); + std::vector xout(ndim); + for (int i = 0; i < ndim; i++) { + int j = ndim - 1 - i; + xdim[j] = input_x_dims[axis[i]]; + xstride[j] = 1; + for (int k = axis[i] + 1; k < ndim; k++) { + xstride[j] *= input_x_dims[k]; + } + xout[j] = xstride[j] * xdim[j]; } - std::vector old_strides; - std::vector new_strides; - for (int i = 0; i < axis.size(); i++) { - int temp_old = 1; - int temp_new = 1; - for (int j = i + 1; j < axis.size(); j++) { - temp_old *= input_x_dims[j]; - temp_new *= new_dims[j]; + auto numel = input_x->numel(); + size_t pind = 0; + std::vector ind(ndim); + for (int i = 0; i < numel; i++) { + out_data[i] = input_x_data[pind]; + ind[0]++; + pind += xstride[0]; + for (int j = 0; j < ndim - 1; j++) { + if (ind[j] == xdim[j]) { + ind[j + 1]++; + ind[j] = 0; + pind += xstride[j + 1]; + pind -= xout[j]; + } else { + break; + } } - old_strides.push_back(temp_old); - new_strides.push_back(temp_new); } - - TransposeFunc(input_x->numel(), input_x_data, axis, old_strides, - new_strides, out_data); } } // namespace operators