diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc index ecf25087004addcdac6d4d0f3268a4d09291d73f..92ef0180ac431a43a0779f15df60d20a88069af3 100644 --- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc +++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc @@ -30,47 +30,52 @@ void QuantDequantFusePass::Apply(const std::unique_ptr& graph) { // releated nodes std::unordered_set quant_types = { "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; + std::vector quant_nodes; for (auto& cur_node : graph->mutable_nodes()) { if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) { - // find input nodes and output nodes - std::list input_nodes = cur_node.inlinks; - std::list output_nodes = cur_node.outlinks; - CHECK_EQ(input_nodes.size(), 2); - CHECK_EQ(output_nodes.size(), 2); - - bool front_is_scale = input_nodes.front()->arg()->is_weight; - Node* input_scale_node = - front_is_scale ? input_nodes.front() : input_nodes.back(); - Node* input_act_node = - front_is_scale ? input_nodes.back() : input_nodes.front(); - front_is_scale = output_nodes.front()->arg()->is_weight; - Node* output_scale_node = - front_is_scale ? output_nodes.front() : output_nodes.back(); - Node* output_act_node = - front_is_scale ? output_nodes.back() : output_nodes.front(); + quant_nodes.push_back(&cur_node); + } + } + for (auto quant_node : quant_nodes) { + // find input nodes and output nodes + std::list input_nodes = quant_node->inlinks; + std::list output_nodes = quant_node->outlinks; + CHECK_EQ(input_nodes.size(), 2); + CHECK_EQ(output_nodes.size(), 2); - // relink nodes and save value to quantized_node - int bit_length = cur_node.stmt()->op_info()->GetAttr("bit_length"); - int range = ((1 << (bit_length - 1)) - 1); - auto* scope = cur_node.stmt()->op()->scope(); - auto scale_tensor = scope->FindVar(output_scale_node->arg()->name) - ->GetMutable(); - float scale_value = scale_tensor->data()[0] / range; + bool front_is_scale = input_nodes.front()->arg()->is_weight; + Node* input_scale_node = + front_is_scale ? input_nodes.front() : input_nodes.back(); + Node* input_act_node = + front_is_scale ? input_nodes.back() : input_nodes.front(); + front_is_scale = output_nodes.front()->arg()->is_weight; + Node* output_scale_node = + front_is_scale ? output_nodes.front() : output_nodes.back(); + Node* output_act_node = + front_is_scale ? output_nodes.back() : output_nodes.front(); - for (auto* quantized_node_ptr : output_act_node->outlinks) { - quantized_node_ptr->stmt()->mutable_op_info()->SetAttr( - "bit_length", bit_length); - quantized_node_ptr->stmt()->mutable_op_info()->SetAttr( - "input_scale", scale_value); - IR_NODE_LINK_TO(input_act_node, quantized_node_ptr) - RemoveDirectedLink(output_act_node, quantized_node_ptr); - } + // relink nodes and save value to quantized_node + int bit_length = quant_node->stmt()->op_info()->GetAttr("bit_length"); + int range = ((1 << (bit_length - 1)) - 1); + auto* scope = quant_node->stmt()->op()->scope(); + auto scale_tensor = scope->FindVar(output_scale_node->arg()->name) + ->GetMutable(); + float scale_value = scale_tensor->data()[0] / range; - // delete nodes and edges - std::unordered_set nodes2rm = { - input_scale_node, &cur_node, output_scale_node, output_act_node}; - GraphSafeRemoveNodes(graph.get(), nodes2rm); + auto outlinks = output_act_node->outlinks; + for (auto* quantized_node_ptr : outlinks) { + quantized_node_ptr->stmt()->mutable_op_info()->SetAttr("bit_length", + bit_length); + quantized_node_ptr->stmt()->mutable_op_info()->SetAttr( + "input_scale", scale_value); + IR_NODE_LINK_TO(input_act_node, quantized_node_ptr) + RemoveDirectedLink(output_act_node, quantized_node_ptr); } + + // delete nodes and edges + std::unordered_set nodes2rm = { + input_scale_node, quant_node, output_scale_node, output_act_node}; + GraphSafeRemoveNodes(graph.get(), nodes2rm); } // fuse quantized node and dequant node