From cf6886078857768c274d4f76889124976160434b Mon Sep 17 00:00:00 2001
From: jiaopu
Date: Wed, 13 May 2020 17:43:27 +0800
Subject: [PATCH] Add layout op

---
 lite/core/mir/mlu_postprocess_pass.cc         |  20 +-
 lite/kernels/mlu/bridges/CMakeLists.txt       |   3 +
 lite/kernels/mlu/bridges/layout_op.cc         | 108 ++++++++++
 lite/kernels/mlu/bridges/layout_op_test.cc    | 190 ++++++++++++++++++
 lite/kernels/mlu/bridges/paddle_use_bridges.h |   1 +
 lite/kernels/mlu/bridges/tensor.cc            |   5 +-
 lite/kernels/mlu/bridges/tensor.h             |   1 +
 lite/kernels/mlu/bridges/test_helper.cc       |  14 +-
 lite/kernels/mlu/bridges/test_helper.h        |   3 +-
 lite/kernels/mlu/bridges/utility.h            | 102 +++++++---
 10 files changed, 407 insertions(+), 40 deletions(-)
 create mode 100644 lite/kernels/mlu/bridges/layout_op.cc
 create mode 100644 lite/kernels/mlu/bridges/layout_op_test.cc

diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index f61b8e1b25..a4fc0d8d4c 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -620,7 +620,8 @@ std::string CheckInputAndInsert(Scope* scope,
   auto layout_op = block_desc->AddOp();
   auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
   scope->Var(layout_arg_name);
-  VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
+  VLOG(5) << "insert layout in subgraph, arg tensor name: "
+          << layout_arg_name;
   layout_op->SetType("layout");
   layout_op->SetInput("Input", {cur_node});
   layout_op->SetOutput("Out", {layout_arg_name});
@@ -663,7 +664,8 @@ std::string CheckOutputAndInsert(Scope* scope,
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
     auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
     scope->Var(layout_arg_name);
-    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
+    VLOG(5) << "insert layout in subgraph, arg tensor name: "
+            << layout_arg_name;
     layout_op = block_desc->AddOp();
     layout_op->SetType("layout");
     layout_op->SetInput("Input", {layout_arg_name});
@@ -709,16 +711,22 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     auto input_name = input->AsArg().name;
     if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
       i_names.emplace_back(input_name);
-      node_replace[input_name] = CheckInputAndInsert(
-          op->scope(), new_block_desc, input_name, input->AsArg().type, subgraph_type);
+      node_replace[input_name] = CheckInputAndInsert(op->scope(),
+                                                     new_block_desc,
+                                                     input_name,
+                                                     input->AsArg().type,
+                                                     subgraph_type);
     }
   }
   for (auto& output : subgraph_node->outlinks) {
     auto output_name = output->AsArg().name;
     if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
       o_names.emplace_back(output_name);
-      node_replace[output_name] = CheckOutputAndInsert(
-          op->scope(), block_desc, output_name, output->AsArg().type, subgraph_type);
+      node_replace[output_name] = CheckOutputAndInsert(op->scope(),
+                                                       block_desc,
+                                                       output_name,
+                                                       output->AsArg().type,
+                                                       subgraph_type);
     }
   }
 
diff --git a/lite/kernels/mlu/bridges/CMakeLists.txt b/lite/kernels/mlu/bridges/CMakeLists.txt
index 5475b1686c..469a16e4bb 100644
--- a/lite/kernels/mlu/bridges/CMakeLists.txt
+++ b/lite/kernels/mlu/bridges/CMakeLists.txt
@@ -23,6 +23,7 @@ lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgrap
 lite_cc_library(subgraph_bridge_slice_op_mlu SRCS slice_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_cast_op_mlu SRCS cast_op.cc DEPS ${subgraph_bridge_deps_mlu})
+lite_cc_library(subgraph_bridge_layout_op_mlu SRCS layout_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
 lite_cc_library(subgraph_bridge_squeeze_op_mlu SRCS squeeze_op.cc DEPS ${subgraph_bridge_deps_mlu})
 set(mlu_subgraph_bridges
@@ -44,6 +45,7 @@ set(mlu_subgraph_bridges
         subgraph_bridge_slice_op_mlu
         subgraph_bridge_split_op_mlu
         subgraph_bridge_cast_op_mlu
+        subgraph_bridge_layout_op_mlu
         subgraph_bridge_argmax_op_mlu
         subgraph_bridge_squeeze_op_mlu
         CACHE INTERNAL "mlu_subgraph_bridges")
@@ -71,6 +73,7 @@ lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope o
 lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_slice_converter_mlu SRCS slice_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+lite_cc_test(test_layout_converter_mlu SRCS layout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_cast_converter_mlu SRCS cast_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
diff --git a/lite/kernels/mlu/bridges/layout_op.cc b/lite/kernels/mlu/bridges/layout_op.cc
new file mode 100644
index 0000000000..9eebccbcd6
--- /dev/null
+++ b/lite/kernels/mlu/bridges/layout_op.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/kernels/mlu/bridges/graph.h" +#include "lite/kernels/mlu/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace mlu { + +int LayoutConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto scope = op->scope(); + VLOG(3) << "[MLU] Converting " + op_type + "..."; + + auto x_var_name = op_info->Input("Input").front(); + auto x = scope->FindVar(x_var_name)->GetMutable(); + auto out_var_name = op_info->Output("Out").front(); + auto output = scope->FindVar(out_var_name)->GetMutable(); + auto output_dims = output->dims().Vectorize(); + std::shared_ptr output_tensor; + + CHECK(graph->HasNode(x_var_name)); + std::vector axis; + auto x_tensor = graph->GetNode(x_var_name); + auto x_data_order = x_tensor->dorder(); + auto x_dims = x->dims().Vectorize(); + if (x_data_order == CNML_NCHW) { + switch (x_dims.size()) { + case 2: + axis = {0, 1}; + break; + case 3: + axis = {0, 2, 1}; + break; + case 4: + axis = {0, 2, 3, 1}; + break; + case 5: + axis = {0, 2, 3, 4, 1}; + break; + default: + CHECK(0) << "Unsupport shape"; + } + output_tensor = graph->AddNode( + out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); + } else { + switch (x_dims.size()) { + case 2: + axis = {0, 1}; + break; + case 3: + axis = {0, 2, 1}; + break; + case 4: + axis = {0, 3, 1, 2}; + break; + case 5: + axis = {0, 4, 1, 2, 3}; + break; + default: + CHECK(0) << "Unsupport shpae"; + } + output_tensor = graph->AddNode(out_var_name, + output_dims, + CNML_TENSOR, + CNML_NCHW, + graph->FPType(), + CNML_NCHW); + } + cnmlBaseOp_t layout_op; + cnmlNdTransposeOpParam_t transpose_param; + CNML_CALL( + cnmlCreateNdTransposeOpParam(&transpose_param, axis.data(), axis.size())); + CNML_CALL(cnmlCreateNdTransposeProOp(&layout_op, + x_tensor->mlu_tensor(), + output_tensor->mlu_tensor(), + transpose_param)); + graph->FuseOp(layout_op); + CNML_CALL(cnmlDestroyBaseOp(&layout_op)); + return SUCCESS; +} + +} // namespace mlu +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(layout, + kMLU, + paddle::lite::subgraph::mlu::LayoutConverter); diff --git a/lite/kernels/mlu/bridges/layout_op_test.cc b/lite/kernels/mlu/bridges/layout_op_test.cc new file mode 100644 index 0000000000..a3a39d9177 --- /dev/null +++ b/lite/kernels/mlu/bridges/layout_op_test.cc @@ -0,0 +1,190 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/layout_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/mlu/bridges/test_helper.h" +#include "lite/kernels/mlu/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace mlu { + +void test_layout_NHWC2NCHW(std::vector input_shape) { + // prepare input&output variables + std::string x_var_name = "input"; + std::string out_var_name = "out"; + + Scope scope; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + x->Resize(DDim(input_shape)); + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("layout"); + opdesc.SetInput("Input", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + + auto op = CreateOp(opdesc, &scope); + + // execute reference implementation and save to output tensor + Tensor input; + input.Resize(DDim(input_shape)); + switch (input_shape.size()) { + case 2: + transpose( + x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), static_cast(input_shape[1])}, + {0, 1}); + break; + case 3: + transpose(x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), + static_cast(input_shape[2]), + static_cast(input_shape[1])}, + {0, 2, 1}); + break; + case 4: + transpose(x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), + static_cast(input_shape[2]), + static_cast(input_shape[3]), + static_cast(input_shape[1])}, + {0, 3, 1, 2}); + break; + case 5: + transpose(x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), + static_cast(input_shape[2]), + static_cast(input_shape[3]), + static_cast(input_shape[4]), + static_cast(input_shape[1])}, + {0, 4, 1, 2, 3}); + break; + default: + CHECK(0) << "Unsupport"; + } + auto* x_data = input.mutable_data(); + LaunchOp(op, {x_var_name}, {out_var_name}); + + // compare results + auto* out_data = out->mutable_data(); + + for (int i = 0; i < out->dims().production(); i++) { + VLOG(5) << i; + EXPECT_NEAR(out_data[i], x_data[i], 5e-4); + } +} + +void test_layout_NCHW2NHWC(std::vector input_shape) { + // prepare input&output variables + std::string x_var_name = "input"; + std::string out_var_name = "out"; + + Scope scope; + auto* x = scope.Var(x_var_name)->GetMutable(); + auto* out = scope.Var(out_var_name)->GetMutable(); + x->Resize(DDim(input_shape)); + // initialize input&output data + FillTensor(x); + + // initialize op desc + cpp::OpDesc opdesc; + opdesc.SetType("layout"); + opdesc.SetInput("Input", {x_var_name}); + opdesc.SetOutput("Out", {out_var_name}); + + auto op = CreateOp(opdesc, &scope); + + // execute reference implementation and save to output tensor + Tensor input; + input.Resize(DDim(input_shape)); + switch (input_shape.size()) { + case 2: + transpose( + x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), static_cast(input_shape[1])}, + {0, 1}); + break; + case 3: + transpose(x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), + static_cast(input_shape[1]), + static_cast(input_shape[2])}, + {0, 2, 1}); + break; + case 4: + transpose(x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), + static_cast(input_shape[1]), + static_cast(input_shape[2]), + static_cast(input_shape[3])}, + {0, 2, 3, 1}); + break; + case 5: + transpose(x->mutable_data(), + input.mutable_data(), + {static_cast(input_shape[0]), + 
+                 static_cast<int>(input_shape[1]),
+                 static_cast<int>(input_shape[2]),
+                 static_cast<int>(input_shape[3]),
+                 static_cast<int>(input_shape[4])},
+                {0, 2, 3, 4, 1});
+      break;
+    default:
+      CHECK(0) << "Unsupported shape";
+  }
+  auto* x_data = input.mutable_data<float>();
+  LaunchOp(op, {x_var_name}, {out_var_name}, CNML_NCHW);
+
+  // compare results
+  auto* out_data = out->mutable_data<float>();
+
+  for (int i = 0; i < out->dims().production(); i++) {
+    VLOG(5) << i;
+    EXPECT_NEAR(out_data[i], x_data[i], 5e-4);
+  }
+}
+
+TEST(MLUBridges, layout) {
+  test_layout_NHWC2NCHW({12, 32, 4});
+  test_layout_NHWC2NCHW({12, 32, 44, 3});
+  test_layout_NHWC2NCHW({12, 32, 44, 3, 6});
+  test_layout_NCHW2NHWC({12, 32, 55});
+  test_layout_NCHW2NHWC({12, 32, 44, 3});
+  test_layout_NCHW2NHWC({12, 32, 44, 3, 8});
+  test_layout_NHWC2NCHW({12, 32});
+  test_layout_NCHW2NHWC({12, 32});
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(layout, kMLU);
diff --git a/lite/kernels/mlu/bridges/paddle_use_bridges.h b/lite/kernels/mlu/bridges/paddle_use_bridges.h
index 8faa7cc478..f286bb66fd 100644
--- a/lite/kernels/mlu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/mlu/bridges/paddle_use_bridges.h
@@ -35,6 +35,7 @@ USE_SUBGRAPH_BRIDGE(dropout, kMLU);
 USE_SUBGRAPH_BRIDGE(argmax, kMLU);
 USE_SUBGRAPH_BRIDGE(split, kMLU);
 USE_SUBGRAPH_BRIDGE(cast, kMLU);
+USE_SUBGRAPH_BRIDGE(layout, kMLU);
 USE_SUBGRAPH_BRIDGE(slice, kMLU);
 USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
 USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
diff --git a/lite/kernels/mlu/bridges/tensor.cc b/lite/kernels/mlu/bridges/tensor.cc
index 02d6149520..b96be8ac1c 100644
--- a/lite/kernels/mlu/bridges/tensor.cc
+++ b/lite/kernels/mlu/bridges/tensor.cc
@@ -246,7 +246,10 @@ void MLUTensor::remember(const std::vector& shape,
       break;
     }
   }
-  dim_ = shape_.size();
+  auto shape_NCHW = DimNHWC2NCHW(shape_);
+  shape_NCHW.erase(shape_NCHW.begin() + shape.size(), shape_NCHW.end());
+  dim_ = shape_NCHW.size();
+  shape_ = DimNCHW2NHWC(shape_NCHW);
 }
 
 void MLUTensor::Create() {
diff --git a/lite/kernels/mlu/bridges/tensor.h b/lite/kernels/mlu/bridges/tensor.h
index 5411b3e081..22268f69ba 100644
--- a/lite/kernels/mlu/bridges/tensor.h
+++ b/lite/kernels/mlu/bridges/tensor.h
@@ -59,6 +59,7 @@ class MLUTensor {
   ~MLUTensor();
 
   void ToFile(std::string file_name);
+  cnmlDataOrder_t dorder() { return data_order_; }
 
  private:
   cnmlTensor_t mlu_tensor_;
diff --git a/lite/kernels/mlu/bridges/test_helper.cc b/lite/kernels/mlu/bridges/test_helper.cc
index a0ddc1b724..5c27bf3d05 100644
--- a/lite/kernels/mlu/bridges/test_helper.cc
+++ b/lite/kernels/mlu/bridges/test_helper.cc
@@ -27,7 +27,8 @@ namespace mlu {
 template
 void PrepareInput(Graph* graph,
                   const std::string& input_name,
-                  Tensor* input_tensor) {
+                  Tensor* input_tensor,
+                  cnmlDataOrder_t order) {
   thread_local Tensor temp_input;
   temp_input.Resize(input_tensor->dims().Vectorize());
   temp_input.CopyDataFrom(*input_tensor);
@@ -38,7 +39,7 @@ void PrepareInput(Graph* graph,
                      CNML_TENSOR,
                      CNML_NCHW,
                      MLUTypeTraits::cnml_type,
-                     CNML_NHWC,
+                     order,
                      reinterpret_cast<void*>(
                          input_tensor->template mutable_data(TARGET(kMLU))));
   CHECK(input_node);
@@ -50,7 +51,8 @@
 
 void LaunchOp(const std::shared_ptr<lite::OpLite> op,
               const std::vector<std::string>& input_var_names,
-              const std::vector<std::string>& output_var_names) {
+              const std::vector<std::string>& output_var_names,
+              cnmlDataOrder_t order) {
   CNRT_CALL(cnrtInit(0));
   lite::SetMluDevice(0);
   cnrtQueue_t queue_;
@@ -77,9 +79,9 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
     auto data_type = input_tensor->precision();
     switch (data_type) {
-#define PREPARE_INPUT(type__)                               \
-  case PRECISION(type__):                                   \
-    PrepareInput(&graph, input_name, input_tensor);         \
+#define PREPARE_INPUT(type__)                               \
+  case PRECISION(type__):                                   \
+    PrepareInput(&graph, input_name, input_tensor, order);  \
     break;
       PREPARE_INPUT(kFP16)
       PREPARE_INPUT(kFloat)
diff --git a/lite/kernels/mlu/bridges/test_helper.h b/lite/kernels/mlu/bridges/test_helper.h
index 4da9e72dfc..36fe6f1efa 100644
--- a/lite/kernels/mlu/bridges/test_helper.h
+++ b/lite/kernels/mlu/bridges/test_helper.h
@@ -58,7 +58,8 @@ void FillTensor(Tensor* x,
 
 void LaunchOp(const std::shared_ptr<lite::OpLite> op,
               const std::vector<std::string>& input_var_names,
-              const std::vector<std::string>& output_var_names);
+              const std::vector<std::string>& output_var_names,
+              cnmlDataOrder_t order = CNML_NHWC);
 
 }  // namespace mlu
 }  // namespace subgraph
diff --git a/lite/kernels/mlu/bridges/utility.h b/lite/kernels/mlu/bridges/utility.h
index d0ed523301..38f3c734b3 100644
--- a/lite/kernels/mlu/bridges/utility.h
+++ b/lite/kernels/mlu/bridges/utility.h
@@ -47,22 +47,74 @@ void transpose(dtype input_data,
                std::vector<int> axis) {
   int old_index = -1;
   int new_index = -1;
-  int dim[4] = {0};
-  std::vector<int> shape = input_shape;
-  for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
-    for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
-      for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
-        for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
-          old_index = dim[0] * shape[1] * shape[2] * shape[3] +
-                      dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + dim[3];
-          new_index =
-              dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * shape[axis[3]] +
-              dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
-              dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
+  if (input_shape.size() == 2) {
+    int dim[2] = {0};
+    std::vector<int> shape = input_shape;
+    for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
+      for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
+        old_index = dim[0] * shape[1] + dim[1];
+        new_index = dim[axis[0]] * shape[axis[1]] + dim[axis[1]];
+        output_data[new_index] = input_data[old_index];
+      }
+    }
+  } else if (input_shape.size() == 3) {
+    int dim[3] = {0};
+    std::vector<int> shape = input_shape;
+    for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
+      for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
+        for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
+          old_index = dim[0] * shape[1] * shape[2] + dim[1] * shape[2] + dim[2];
+          new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] +
+                      dim[axis[1]] * shape[axis[2]] + dim[axis[2]];
           output_data[new_index] = input_data[old_index];
         }
       }
     }
+  } else if (input_shape.size() == 4) {
+    int dim[4] = {0};
+    std::vector<int> shape = input_shape;
+    for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
+      for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
+        for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
+          for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
+            old_index = dim[0] * shape[1] * shape[2] * shape[3] +
+                        dim[1] * shape[2] * shape[3] + dim[2] * shape[3] +
+                        dim[3];
+            new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] *
+                            shape[axis[3]] +
+                        dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
+                        dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
+            output_data[new_index] = input_data[old_index];
+          }
+        }
+      }
+    }
+  } else if (input_shape.size() == 5) {
+    int dim[5] = {0};
+    std::vector<int> shape = input_shape;
+    for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
+      for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
+        for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
+          for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
+            for (dim[4] = 0; dim[4] < input_shape[4]; dim[4]++) {
+              old_index = dim[0] * shape[1] * shape[2] * shape[3] * shape[4] +
+                          dim[1] * shape[2] * shape[3] * shape[4] +
+                          dim[2] * shape[3] * shape[4] + dim[3] * shape[4] +
+                          dim[4];
+              new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] *
+                              shape[axis[3]] * shape[axis[4]] +
+                          dim[axis[1]] * shape[axis[2]] * shape[axis[3]] *
+                              shape[axis[4]] +
+                          dim[axis[2]] * shape[axis[3]] * shape[axis[4]] +
+                          dim[axis[3]] * shape[axis[4]] + dim[axis[4]];
+              output_data[new_index] = input_data[old_index];
+            }
+          }
+        }
+      }
+    }
+
+  } else {
   }
 }
 
@@ -103,41 +155,39 @@ inline const ::paddle::lite::DDimLite DimNCHW2NHWC(
       std::vector<int64_t>({dim[0], dim[2], dim[3], dim[1]}));
 }
 
-inline const std::vector<int64_t> DimNHWC2NCHW(
-    const std::vector<int64_t>& dim) {
+template <typename data_type>
+inline const std::vector<data_type> DimNHWC2NCHW(
+    const std::vector<data_type>& dim) {
   switch (dim.size()) {
     case 1:
       return dim;
     case 2:
       return dim;
     case 3:
-      return std::vector<int64_t>({dim[0], dim[2], dim[1]});
+      return std::vector<data_type>({dim[0], dim[2], dim[1]});
     case 4:
-      return std::vector<int64_t>(
-          {dim[0], dim[3], dim[1], dim[2]});
+      return std::vector<data_type>({dim[0], dim[3], dim[1], dim[2]});
     case 5:
-      return std::vector<int64_t>(
-          {dim[0], dim[4], dim[1], dim[2], dim[3]});
+      return std::vector<data_type>({dim[0], dim[4], dim[1], dim[2], dim[3]});
     default:
       CHECK(0) << "unsupport dimension";
   }
 }
 
-inline const std::vector<int64_t> DimNCHW2NHWC(
-    const std::vector<int64_t>& dim) {
+template <typename data_type>
+inline const std::vector<data_type> DimNCHW2NHWC(
+    const std::vector<data_type>& dim) {
   switch (dim.size()) {
     case 1:
       return dim;
     case 2:
       return dim;
     case 3:
-      return std::vector<int64_t>({dim[0], dim[2], dim[1]});
+      return std::vector<data_type>({dim[0], dim[2], dim[1]});
    case 4:
-      return std::vector<int64_t>(
-          {dim[0], dim[2], dim[3], dim[1]});
+      return std::vector<data_type>({dim[0], dim[2], dim[3], dim[1]});
    case 5:
-      return std::vector<int64_t>(
-          {dim[0], dim[2], dim[3], dim[4], dim[1]});
+      return std::vector<data_type>({dim[0], dim[2], dim[3], dim[4], dim[1]});
    default:
      CHECK(0) << "unsupport dimension";
   }
-- 
GitLab