From 4030e756fbe039563beaf979231f661b38bc7c55 Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Tue, 16 Jun 2020 21:59:02 -0500 Subject: [PATCH] [arm]fix concat axis < 0 compute error problem (#3802) * fix concat axis < 0 error,test=develop * fix format. test=develop --- lite/core/mir/fusion/conv_activation_fuse_pass.cc | 8 ++++---- lite/kernels/arm/concat_compute.cc | 9 ++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc index 68c07c0ffd..fa89bc2a5f 100644 --- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc @@ -25,21 +25,21 @@ namespace mir { void ConvActivationFusePass::Apply(const std::unique_ptr& graph) { std::vector act_types{"relu"}; bool has_int8 = false; - bool has_arm_float = false; + bool has_arm = false; bool has_cuda = false; for (auto& place : graph->valid_places()) { if (place.precision == PRECISION(kInt8)) { has_int8 = true; } - if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) { - has_arm_float = true; + if (place.target == TARGET(kARM)) { + has_arm = true; } if (place.target == TARGET(kCUDA)) { has_cuda = true; } } - if (!has_int8 && has_arm_float) { + if (has_arm) { act_types.push_back("relu6"); act_types.push_back("leaky_relu"); } diff --git a/lite/kernels/arm/concat_compute.cc b/lite/kernels/arm/concat_compute.cc index dc78e1b955..9ab4ca54bb 100644 --- a/lite/kernels/arm/concat_compute.cc +++ b/lite/kernels/arm/concat_compute.cc @@ -52,11 +52,7 @@ void ConcatFunc(const std::vector inputs, output_offset += in_stride[0]; } } else { - std::vector inputs_concat(inputs.size()); - for (int j = 0; j < inputs.size(); ++j) { - inputs_concat[j] = inputs[j]; - } - lite::arm::math::concat_func(inputs_concat, axis, out); + lite::arm::math::concat_func(inputs, axis, out); } } @@ -71,6 +67,9 @@ void ConcatCompute::Run() { auto* axis_tensor_data = 
axis_tensor->data(); axis = axis_tensor_data[0]; } + if (axis < 0) { + axis += inputs[0]->dims().size(); + } switch (inputs.front()->precision()) { case PRECISION(kFloat): -- GitLab