From 875d456392269a927273eb2b0fdb7585c1ccb791 Mon Sep 17 00:00:00 2001
From: hong19860320 <9973393+hong19860320@users.noreply.github.com>
Date: Fri, 7 Aug 2020 18:11:06 +0800
Subject: [PATCH] [Core] Fix the missing of the input and output scale after
 the lite_elementwise_activation_fuse_pass is applied (#4066)

---
 .../fusion/elementwise_add_activation_fuser.cc     | 16 +++++++---------
 lite/core/mir/fusion/scale_activation_fuser.cc     | 15 +++++++++------
 .../mir/quantized_op_attributes_inference_pass.cc  |  1 +
 3 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/lite/core/mir/fusion/elementwise_add_activation_fuser.cc b/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
index 28081748a7..2e401fa62e 100644
--- a/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+++ b/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
@@ -75,9 +75,8 @@ void ElementwiseActivationFuser::InsertNewNode(SSAGraph* graph,
 }
 
 cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("elt")->stmt()->op_info();
-
-  cpp::OpDesc op_desc;
+  auto op_desc = *matched.at("elt")->stmt()->op_info();
+  auto* act_op_desc = matched.at("act")->stmt()->op_info();
   if (eltwise_type_ == "elementwise_add") {
     op_desc.SetType("fusion_elementwise_add_activation");
   } else if (eltwise_type_ == "elementwise_sub") {
@@ -87,13 +86,12 @@ cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
   } else {
     LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_;
   }
-
-  op_desc.SetInput("X", {matched.at("x")->arg()->name});
-  op_desc.SetInput("Y", {matched.at("y")->arg()->name});
-  op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
-
-  op_desc.SetAttr("axis", desc->GetAttr<int>("axis"));
   op_desc.SetAttr("act_type", act_type_);
+  auto& out_name = matched.at("output")->arg()->name;
+  op_desc.SetOutput("Out", {out_name});
+  if (act_op_desc->HasOutputScale(out_name)) {
+    op_desc.SetOutputScale(out_name, act_op_desc->GetOutputScale(out_name));
+  }
   return op_desc;
 }
 
diff --git a/lite/core/mir/fusion/scale_activation_fuser.cc b/lite/core/mir/fusion/scale_activation_fuser.cc
index 4f18099da8..b9ae3a9520 100644
--- a/lite/core/mir/fusion/scale_activation_fuser.cc
+++ b/lite/core/mir/fusion/scale_activation_fuser.cc
@@ -61,20 +61,23 @@ void ScaleActivationFuser::InsertNewNode(SSAGraph* graph,
 }
 
 cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) {
-  cpp::OpDesc op_desc = *matched.at("scale")->stmt()->op_info();
-  op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
-  cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
-
+  auto op_desc = *matched.at("scale")->stmt()->op_info();
+  auto* act_op_desc = matched.at("act")->stmt()->op_info();
   op_desc.SetAttr("activation_type", act_type_);
   if (act_type_ == "relu") {
     op_desc.SetAttr("fuse_relu", true);
   } else if (act_type_ == "relu6") {
-    float alpha = act_op_desc.GetAttr<float>("threshold");
+    float alpha = act_op_desc->GetAttr<float>("threshold");
     op_desc.SetAttr("alpha", alpha);
   } else if (act_type_ == "leaky_relu") {
-    float alpha = act_op_desc.GetAttr<float>("alpha");
+    float alpha = act_op_desc->GetAttr<float>("alpha");
     op_desc.SetAttr("alpha", alpha);
   }
+  auto& out_name = matched.at("output")->arg()->name;
+  op_desc.SetOutput("Out", {out_name});
+  if (act_op_desc->HasOutputScale(out_name)) {
+    op_desc.SetOutputScale(out_name, act_op_desc->GetOutputScale(out_name));
+  }
   return op_desc;
 }
 
diff --git a/lite/core/mir/quantized_op_attributes_inference_pass.cc b/lite/core/mir/quantized_op_attributes_inference_pass.cc
index 259447aa21..6cceb18d90 100644
--- a/lite/core/mir/quantized_op_attributes_inference_pass.cc
+++ b/lite/core/mir/quantized_op_attributes_inference_pass.cc
@@ -32,6 +32,7 @@ void QuantizedOpAttributesInferencePass::Apply(
   // Only for fully quantized model which is only supported by MTK and RK NPU.
   // Replace the output_scale with the input_scale of the adjacent quantized
   // ops, and fix the missing of the attribute 'enable_int8'.
+  VLOG(5) << "\n" << Visualize(graph.get());
   for (auto& op_node : graph->StmtTopologicalOrder()) {
     if (!op_node->IsStmt()) continue;
     auto& inst = op_node->AsStmt();
-- 
GitLab