未验证 提交 9cc7dfa8 编写于 作者: J juncaipeng 提交者: GitHub

Fix quant dequant fuse pass (#2190)


* fix bug for accessing the removed node, test=develop
上级 4d530acc
...@@ -30,47 +30,52 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -30,47 +30,52 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// releated nodes // releated nodes
std::unordered_set<std::string> quant_types = { std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::vector<Node*> quant_nodes;
for (auto& cur_node : graph->mutable_nodes()) { for (auto& cur_node : graph->mutable_nodes()) {
if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) { if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
// find input nodes and output nodes quant_nodes.push_back(&cur_node);
std::list<Node*> input_nodes = cur_node.inlinks; }
std::list<Node*> output_nodes = cur_node.outlinks; }
CHECK_EQ(input_nodes.size(), 2); for (auto quant_node : quant_nodes) {
CHECK_EQ(output_nodes.size(), 2); // find input nodes and output nodes
std::list<Node*> input_nodes = quant_node->inlinks;
bool front_is_scale = input_nodes.front()->arg()->is_weight; std::list<Node*> output_nodes = quant_node->outlinks;
Node* input_scale_node = CHECK_EQ(input_nodes.size(), 2);
front_is_scale ? input_nodes.front() : input_nodes.back(); CHECK_EQ(output_nodes.size(), 2);
Node* input_act_node =
front_is_scale ? input_nodes.back() : input_nodes.front();
front_is_scale = output_nodes.front()->arg()->is_weight;
Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
// relink nodes and save value to quantized_node bool front_is_scale = input_nodes.front()->arg()->is_weight;
int bit_length = cur_node.stmt()->op_info()->GetAttr<int>("bit_length"); Node* input_scale_node =
int range = ((1 << (bit_length - 1)) - 1); front_is_scale ? input_nodes.front() : input_nodes.back();
auto* scope = cur_node.stmt()->op()->scope(); Node* input_act_node =
auto scale_tensor = scope->FindVar(output_scale_node->arg()->name) front_is_scale ? input_nodes.back() : input_nodes.front();
->GetMutable<lite::Tensor>(); front_is_scale = output_nodes.front()->arg()->is_weight;
float scale_value = scale_tensor->data<float>()[0] / range; Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
for (auto* quantized_node_ptr : output_act_node->outlinks) { // relink nodes and save value to quantized_node
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>( int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
"bit_length", bit_length); int range = ((1 << (bit_length - 1)) - 1);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>( auto* scope = quant_node->stmt()->op()->scope();
"input_scale", scale_value); auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr) ->GetMutable<lite::Tensor>();
RemoveDirectedLink(output_act_node, quantized_node_ptr); float scale_value = scale_tensor->data<float>()[0] / range;
}
// delete nodes and edges auto outlinks = output_act_node->outlinks;
std::unordered_set<const Node*> nodes2rm = { for (auto* quantized_node_ptr : outlinks) {
input_scale_node, &cur_node, output_scale_node, output_act_node}; quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>("bit_length",
GraphSafeRemoveNodes(graph.get(), nodes2rm); bit_length);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
"input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
RemoveDirectedLink(output_act_node, quantized_node_ptr);
} }
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph.get(), nodes2rm);
} }
// fuse quantized node and dequant node // fuse quantized node and dequant node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册