未验证 提交 875d4563 编写于 作者: H hong19860320 提交者: GitHub

[Core] Fix the missing of the input and output scale after the...

[Core] Fix the missing of the input and output scale after the lite_elementwise_activation_fuse_pass is applied (#4066)
上级 7e92cc7f
...@@ -75,9 +75,8 @@ void ElementwiseActivationFuser::InsertNewNode(SSAGraph* graph, ...@@ -75,9 +75,8 @@ void ElementwiseActivationFuser::InsertNewNode(SSAGraph* graph,
} }
cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("elt")->stmt()->op_info(); auto op_desc = *matched.at("elt")->stmt()->op_info();
auto* act_op_desc = matched.at("act")->stmt()->op_info();
cpp::OpDesc op_desc;
if (eltwise_type_ == "elementwise_add") { if (eltwise_type_ == "elementwise_add") {
op_desc.SetType("fusion_elementwise_add_activation"); op_desc.SetType("fusion_elementwise_add_activation");
} else if (eltwise_type_ == "elementwise_sub") { } else if (eltwise_type_ == "elementwise_sub") {
...@@ -87,13 +86,12 @@ cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -87,13 +86,12 @@ cpp::OpDesc ElementwiseActivationFuser::GenOpDesc(const key2nodes_t& matched) {
} else { } else {
LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_; LOG(FATAL) << "not supported elementwise_type: " << eltwise_type_;
} }
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetInput("Y", {matched.at("y")->arg()->name});
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
op_desc.SetAttr("axis", desc->GetAttr<int>("axis"));
op_desc.SetAttr("act_type", act_type_); op_desc.SetAttr("act_type", act_type_);
auto& out_name = matched.at("output")->arg()->name;
op_desc.SetOutput("Out", {out_name});
if (act_op_desc->HasOutputScale(out_name)) {
op_desc.SetOutputScale(out_name, act_op_desc->GetOutputScale(out_name));
}
return op_desc; return op_desc;
} }
......
...@@ -61,20 +61,23 @@ void ScaleActivationFuser::InsertNewNode(SSAGraph* graph, ...@@ -61,20 +61,23 @@ void ScaleActivationFuser::InsertNewNode(SSAGraph* graph,
} }
cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) { cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("scale")->stmt()->op_info(); auto op_desc = *matched.at("scale")->stmt()->op_info();
op_desc.SetOutput("Out", {matched.at("output")->arg()->name}); auto* act_op_desc = matched.at("act")->stmt()->op_info();
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
op_desc.SetAttr("activation_type", act_type_); op_desc.SetAttr("activation_type", act_type_);
if (act_type_ == "relu") { if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true); op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "relu6") { } else if (act_type_ == "relu6") {
float alpha = act_op_desc.GetAttr<float>("threshold"); float alpha = act_op_desc->GetAttr<float>("threshold");
op_desc.SetAttr("alpha", alpha); op_desc.SetAttr("alpha", alpha);
} else if (act_type_ == "leaky_relu") { } else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha"); float alpha = act_op_desc->GetAttr<float>("alpha");
op_desc.SetAttr("alpha", alpha); op_desc.SetAttr("alpha", alpha);
} }
auto& out_name = matched.at("output")->arg()->name;
op_desc.SetOutput("Out", {out_name});
if (act_op_desc->HasOutputScale(out_name)) {
op_desc.SetOutputScale(out_name, act_op_desc->GetOutputScale(out_name));
}
return op_desc; return op_desc;
} }
......
...@@ -32,6 +32,7 @@ void QuantizedOpAttributesInferencePass::Apply( ...@@ -32,6 +32,7 @@ void QuantizedOpAttributesInferencePass::Apply(
// Only for fully quantized model which is only supported by MTK and RK NPU. // Only for fully quantized model which is only supported by MTK and RK NPU.
// Replace the output_scale with the input_scale of the adjacent quantized // Replace the output_scale with the input_scale of the adjacent quantized
// ops, and fix the missing of the attribute 'enable_int8'. // ops, and fix the missing of the attribute 'enable_int8'.
VLOG(5) << "\n" << Visualize(graph.get());
for (auto& op_node : graph->StmtTopologicalOrder()) { for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue; if (!op_node->IsStmt()) continue;
auto& inst = op_node->AsStmt(); auto& inst = op_node->AsStmt();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册