From d6791276a1ef4eae38956fdb168ba467333f27d4 Mon Sep 17 00:00:00 2001
From: dingminghui
Date: Tue, 26 May 2020 15:59:12 +0800
Subject: [PATCH] refactor: abstract function to generate axes trans vector

---
 lite/kernels/mlu/bridges/concat_op.cc     | 10 ++----
 lite/kernels/mlu/bridges/flatten_op.cc    | 38 ++---------------------
 lite/kernels/mlu/bridges/reshape_op.cc    | 38 ++---------------------
 lite/kernels/mlu/bridges/slice_op.cc      |  7 +----
 lite/kernels/mlu/bridges/slice_op_test.cc | 32 ++++++-------------
 lite/kernels/mlu/bridges/softmax_op.cc    | 12 +++----
 lite/kernels/mlu/bridges/transpose_op.cc  | 15 ++-------
 lite/kernels/mlu/bridges/utility.h        | 28 +++++++++++++++--
 8 files changed, 49 insertions(+), 131 deletions(-)

diff --git a/lite/kernels/mlu/bridges/concat_op.cc b/lite/kernels/mlu/bridges/concat_op.cc
index 76037c3358..1d56663993 100644
--- a/lite/kernels/mlu/bridges/concat_op.cc
+++ b/lite/kernels/mlu/bridges/concat_op.cc
@@ -45,13 +45,9 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto dims = output_dims.size();
   int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
   CHECK_LT(axis, dims) << "Unsupport dims in mlu concat";
-  std::vector<int> nchw2nhwc_axis(dims);
-  nchw2nhwc_axis[0] = 0;
-  if (dims > 1) nchw2nhwc_axis[1] = dims - 1;
-  for (size_t i = 2; i < dims; ++i) {
-    nchw2nhwc_axis[i] = i - 1;
-  }
-  int nhwc_axis = nchw2nhwc_axis[axis];
+  // The trans vector is ordered by NCHW axes; each value is the matching
+  // NHWC index, so indexing it with an NCHW axis yields the NHWC axis.
+  int nhwc_axis = GetAxisNHWC2NCHW<int>(dims)[axis];
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
diff --git a/lite/kernels/mlu/bridges/flatten_op.cc b/lite/kernels/mlu/bridges/flatten_op.cc
index 5b36cba6aa..faf7e6fd28 100644
--- a/lite/kernels/mlu/bridges/flatten_op.cc
+++ b/lite/kernels/mlu/bridges/flatten_op.cc
@@ -38,24 +38,7 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 
   // ================== Trans1: NHWC => NCHW ===========================
   auto input_tensor = graph->GetNode(x_var_name);
-  // std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
-  std::vector<int> trans_1_axis;
-  switch (x->dims().size()) {
-    case 4:
-      trans_1_axis = {0, 3, 1, 2};
-      break;
-    case 3:
-      trans_1_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_1_axis = {0, 1};
-      break;
-    case 1:
-      trans_1_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_1_axis = GetAxisNHWC2NCHW<int>(x->dims().size());
   auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
                                    x->dims().Vectorize(),
                                    CNML_TENSOR,
@@ -95,24 +78,7 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ======================= Flatten End ===================================
 
   // ================== Trans2: NCHW => NHWC ===============================
-  // std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
-  std::vector<int> trans_2_axis;
-  switch (output->dims().size()) {
-    case 4:
-      trans_2_axis = {0, 2, 3, 1};
-      break;
-    case 3:
-      trans_2_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_2_axis = {0, 1};
-      break;
-    case 1:
-      trans_2_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_2_axis = GetAxisNCHW2NHWC<int>(output->dims().size());
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t trans2_op{nullptr};
diff --git a/lite/kernels/mlu/bridges/reshape_op.cc b/lite/kernels/mlu/bridges/reshape_op.cc
index 4a54e64dfe..0b47322b34 100644
--- a/lite/kernels/mlu/bridges/reshape_op.cc
+++ b/lite/kernels/mlu/bridges/reshape_op.cc
@@ -38,24 +38,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 
   // ================== Trans1: NHWC => NCHW ===========================
   auto input_tensor = graph->GetNode(x_var_name);
-  // std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
-  std::vector<int> trans_1_axis;
-  switch (x->dims().size()) {
-    case 4:
-      trans_1_axis = {0, 3, 1, 2};
-      break;
-    case 3:
-      trans_1_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_1_axis = {0, 1};
-      break;
-    case 1:
-      trans_1_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_1_axis = GetAxisNHWC2NCHW<int>(x->dims().size());
   auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
                                    x->dims().Vectorize(),
                                    CNML_TENSOR,
@@ -95,24 +78,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ======================= Reshape op End ===================================
 
   // ================== Trans2: NCHW => NHWC ===============================
-  // std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
-  std::vector<int> trans_2_axis;
-  switch (output->dims().size()) {
-    case 4:
-      trans_2_axis = {0, 2, 3, 1};
-      break;
-    case 3:
-      trans_2_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_2_axis = {0, 1};
-      break;
-    case 1:
-      trans_2_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_2_axis = GetAxisNCHW2NHWC<int>(output->dims().size());
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t trans2_op{nullptr};
diff --git a/lite/kernels/mlu/bridges/slice_op.cc b/lite/kernels/mlu/bridges/slice_op.cc
index deb48f903d..067d110bf4 100644
--- a/lite/kernels/mlu/bridges/slice_op.cc
+++ b/lite/kernels/mlu/bridges/slice_op.cc
@@ -53,12 +53,7 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   std::vector<int64_t> begin_index(input_shape.size(), 0);
   std::vector<int64_t> end_index(input_shape.size());
   std::vector<int64_t> strides(input_shape.size(), 1);
-  std::vector<int> nhwc2nchw_axis(input_shape.size());
-  nhwc2nchw_axis[0] = 0;
-  if (input_shape.size() > 1) nhwc2nchw_axis[1] = input_shape.size() - 1;
-  for (size_t i = 2; i < input_shape.size(); ++i) {
-    nhwc2nchw_axis[i] = i - 1;
-  }
+  auto nhwc2nchw_axis = GetAxisNHWC2NCHW<int>(input_shape.size());
   for (size_t i = 0; i < input_shape.size(); ++i) {
     end_index[nhwc2nchw_axis[i]] = input_shape[i];
   }
diff --git a/lite/kernels/mlu/bridges/slice_op_test.cc b/lite/kernels/mlu/bridges/slice_op_test.cc
index 29f698657c..8c09e84861 100644
--- a/lite/kernels/mlu/bridges/slice_op_test.cc
+++ b/lite/kernels/mlu/bridges/slice_op_test.cc
@@ -108,31 +108,19 @@ static void test_case(std::vector<int64_t> x_shape,
   std::vector<float> out_ref(out->data_size(), 0);
   slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
 
-  std::vector<int> nhwc2nchw_axis(x_shape.size());
-  nhwc2nchw_axis[0] = 0;
-  if (x_shape.size() > 1) nhwc2nchw_axis[1] = x_shape.size() - 1;
-  for (size_t i = 2; i < x_shape.size(); ++i) {
-    nhwc2nchw_axis[i] = i - 1;
-  }
-
-  std::vector<int> nchw2nhwc_axis(x_shape.size());
-  nchw2nhwc_axis[0] = 0;
-  for (size_t i = 1; i < x_shape.size() - 1; ++i) {
-    nchw2nhwc_axis[i] = i + 1;
-  }
-  if (x_shape.size() > 1) nchw2nhwc_axis[x_shape.size() - 1] = 1;
-
   auto type_cast = [](int64_t in) { return static_cast<int>(in); };
   std::vector<int> i_dims;
   std::transform(
       x_shape.cbegin(), x_shape.cend(), std::back_inserter(i_dims), type_cast);
 
+  auto nchw2nhwc_axis = GetAxisNCHW2NHWC<int>(x_shape.size());
+
   Tensor input_x;
   input_x.Resize(x->dims());
-  transpose(x->mutable_data<float>(),
-            input_x.mutable_data<float>(),
-            i_dims,
-            nchw2nhwc_axis);
+  transpose(x->mutable_data<float>(),
+            input_x.mutable_data<float>(),
+            i_dims,
+            nchw2nhwc_axis);
   x->CopyDataFrom(input_x);
 
   auto op = CreateOp<operators::SliceOp>(opdesc, &scope);
@@ -145,10 +133,10 @@ static void test_case(std::vector<int64_t> x_shape,
   for (size_t i = 0; i < os.size(); ++i) {
     o_dims[i] = os[nchw2nhwc_axis[i]];
   }
-  transpose(out->mutable_data<float>(),
-            output_trans.mutable_data<float>(),
-            o_dims,
-            nhwc2nchw_axis);
+  transpose(out->mutable_data<float>(),
+            output_trans.mutable_data<float>(),
+            o_dims,
+            GetAxisNHWC2NCHW<int>(x_shape.size()));
 
   auto out_data = output_trans.mutable_data<float>();
   for (int i = 0; i < out->dims().production(); i++) {
diff --git a/lite/kernels/mlu/bridges/softmax_op.cc b/lite/kernels/mlu/bridges/softmax_op.cc
index 36732fd5c2..b1b621c1ef 100644
--- a/lite/kernels/mlu/bridges/softmax_op.cc
+++ b/lite/kernels/mlu/bridges/softmax_op.cc
@@ -38,13 +38,7 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 
   auto x_shape =
       scope->FindVar(x_var_name)->GetMutable<Tensor>()->dims().Vectorize();
-  // nchw axis to nhwc aixs
-  std::vector<int> nchw2nhwc_axis(x_shape.size());
-  nchw2nhwc_axis[0] = 0;
-  if (x_shape.size() > 1) nchw2nhwc_axis[1] = x_shape.size() - 1;
-  for (size_t i = 2; i < x_shape.size(); ++i) {
-    nchw2nhwc_axis[i] = i - 1;
-  }
+  // nchw axis to nhwc axis
   int axis = 1;
   if (op_info->HasAttr("axis")) {
     axis = op_info->GetAttr<int>("axis");
@@ -52,7 +46,9 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       axis = output_dims.size() + axis;
     }
   }
-  int nhwc_axis = nchw2nhwc_axis[axis];
+  // The trans vector is ordered by NCHW axes; each value is the matching
+  // NHWC index, so indexing it with an NCHW axis yields the NHWC axis.
+  int nhwc_axis = GetAxisNHWC2NCHW<int>(x_shape.size())[axis];
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
diff --git a/lite/kernels/mlu/bridges/transpose_op.cc b/lite/kernels/mlu/bridges/transpose_op.cc
index 43f25550c2..b6caeb3613 100644
--- a/lite/kernels/mlu/bridges/transpose_op.cc
+++ b/lite/kernels/mlu/bridges/transpose_op.cc
@@ -24,19 +24,8 @@ namespace mlu {
 
 std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
   std::vector<int> new_axis(axis.size());
-  std::vector<int> nhwc2nchw_axis(axis.size());
-  nhwc2nchw_axis[0] = 0;
-  if (axis.size() > 1) nhwc2nchw_axis[1] = axis.size() - 1;
-  for (size_t i = 2; i < axis.size(); ++i) {
-    nhwc2nchw_axis[i] = i - 1;
-  }
-
-  std::vector<int> nchw2nhwc_axis(axis.size());
-  nchw2nhwc_axis[0] = 0;
-  for (size_t i = 1; i < axis.size() - 1; ++i) {
-    nchw2nhwc_axis[i] = i + 1;
-  }
-  if (axis.size() > 1) nchw2nhwc_axis[axis.size() - 1] = 1;
+  auto nhwc2nchw_axis = GetAxisNHWC2NCHW<int>(axis.size());
+  auto nchw2nhwc_axis = GetAxisNCHW2NHWC<int>(axis.size());
 
   for (size_t i = 0; i < new_axis.size(); ++i) {
     new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
diff --git a/lite/kernels/mlu/bridges/utility.h b/lite/kernels/mlu/bridges/utility.h
index 78f862c0d3..fd1e5eb265 100644
--- a/lite/kernels/mlu/bridges/utility.h
+++ b/lite/kernels/mlu/bridges/utility.h
@@ -44,12 +44,12 @@ void transpose(dtype* input_data,
   int new_index = -1;
   std::vector<int> shape;
   std::vector<int> expand_axis;
-  if (input_shape.size() < 5) {
-    for (int i = 0; i < 5 - input_shape.size(); i++) {
+  if (input_shape.size() < 5u) {
+    for (size_t i = 0; i < 5 - input_shape.size(); i++) {
       shape.push_back(1);
       expand_axis.push_back(i);
     }
-    for (int i = 0; i < input_shape.size(); i++) {
+    for (size_t i = 0; i < input_shape.size(); i++) {
       shape.push_back(input_shape[i]);
       expand_axis.push_back(axis[i] + 5 - input_shape.size());
     }
@@ -154,6 +154,28 @@ inline const std::vector<int64_t> DimNCHW2NHWC(
   }
 }
 
+template <typename T>
+inline std::vector<T> GetAxisNHWC2NCHW(size_t n_dims) {
+  std::vector<T> nhwc2nchw_axis(n_dims);
+  nhwc2nchw_axis[0] = 0;
+  if (n_dims > 1) nhwc2nchw_axis[1] = n_dims - 1;
+  for (size_t i = 2; i < n_dims; ++i) {
+    nhwc2nchw_axis[i] = i - 1;
+  }
+  return nhwc2nchw_axis;
+}
+
+template <typename T>
+inline std::vector<T> GetAxisNCHW2NHWC(size_t n_dims) {
+  std::vector<T> nchw2nhwc_axis(n_dims);
+  nchw2nhwc_axis[0] = 0;
+  for (size_t i = 1; i < n_dims - 1; ++i) {
+    nchw2nhwc_axis[i] = i + 1;
+  }
+  if (n_dims > 1) nchw2nhwc_axis[n_dims - 1] = 1;
+  return nchw2nhwc_axis;
+}
+
 template <typename T>
 struct MLUTypeTraits {
   /* using type = void; */
-- 
GitLab
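
Both helpers return a permutation vector that is indexed by destination-layout
axes and whose values are the corresponding source-layout indices. The
standalone sketch below sanity-checks them in isolation; the two template
bodies are copied verbatim from the utility.h hunk above, while the file name,
main(), and the printed values are illustrative assumptions, not part of the
commit.

// axes_trans_demo.cc -- illustrative sketch, not part of the patch.
// Build: g++ -std=c++11 axes_trans_demo.cc && ./a.out
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Copied from the utility.h hunk above: NHWC -> NCHW permutation.
template <typename T>
inline std::vector<T> GetAxisNHWC2NCHW(size_t n_dims) {
  std::vector<T> nhwc2nchw_axis(n_dims);
  nhwc2nchw_axis[0] = 0;
  if (n_dims > 1) nhwc2nchw_axis[1] = n_dims - 1;
  for (size_t i = 2; i < n_dims; ++i) {
    nhwc2nchw_axis[i] = i - 1;
  }
  return nhwc2nchw_axis;
}

// Copied from the utility.h hunk above: NCHW -> NHWC permutation.
template <typename T>
inline std::vector<T> GetAxisNCHW2NHWC(size_t n_dims) {
  std::vector<T> nchw2nhwc_axis(n_dims);
  nchw2nhwc_axis[0] = 0;
  for (size_t i = 1; i < n_dims - 1; ++i) {
    nchw2nhwc_axis[i] = i + 1;
  }
  if (n_dims > 1) nchw2nhwc_axis[n_dims - 1] = 1;
  return nchw2nhwc_axis;
}

int main() {
  // For 4-D tensors the helpers reproduce the hard-coded switch tables
  // that the patch deletes from flatten_op.cc and reshape_op.cc.
  auto to_nchw = GetAxisNHWC2NCHW<int>(4);
  auto to_nhwc = GetAxisNCHW2NHWC<int>(4);
  std::cout << "NHWC->NCHW:";
  for (int v : to_nchw) std::cout << ' ' << v;  // 0 3 1 2
  std::cout << "\nNCHW->NHWC:";
  for (int v : to_nhwc) std::cout << ' ' << v;  // 0 2 3 1
  std::cout << '\n';

  // The two permutations are mutual inverses for ranks 1..4, which is what
  // axis_to_nhwc in transpose_op.cc relies on when it composes them.
  for (size_t n = 1; n <= 4; ++n) {
    auto a = GetAxisNHWC2NCHW<int>(n);
    auto b = GetAxisNCHW2NHWC<int>(n);
    for (size_t i = 0; i < n; ++i) {
      assert(a[b[i]] == static_cast<int>(i) && b[a[i]] == static_cast<int>(i));
    }
  }

  // Usage as in concat_op.cc/softmax_op.cc: remap an NCHW axis attribute to
  // the axis of the NHWC tensor the MLU actually holds, e.g. channels 1 -> 3.
  int param_axis = 1;
  std::cout << "nhwc_axis = " << to_nchw[param_axis] << '\n';  // prints 3
  return 0;
}

Because the vector maps destination axes to source axes, a single helper covers
both uses in the patch: GetAxisNHWC2NCHW<int>(dims)[axis] remaps a scalar NCHW
axis attribute (concat, softmax), and the whole vector serves directly as the
permutation argument of transpose() (slice test, flatten, reshape).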