未验证 提交 9cc7dfa8 编写于 作者: J juncaipeng 提交者: GitHub

Fix quant dequant fuse pass (#2190)


* fix bug for accessing the removed node, test=develop
上级 4d530acc
......@@ -30,47 +30,52 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// releated nodes
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::vector<Node*> quant_nodes;
for (auto& cur_node : graph->mutable_nodes()) {
if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
// find input nodes and output nodes
std::list<Node*> input_nodes = cur_node.inlinks;
std::list<Node*> output_nodes = cur_node.outlinks;
CHECK_EQ(input_nodes.size(), 2);
CHECK_EQ(output_nodes.size(), 2);
bool front_is_scale = input_nodes.front()->arg()->is_weight;
Node* input_scale_node =
front_is_scale ? input_nodes.front() : input_nodes.back();
Node* input_act_node =
front_is_scale ? input_nodes.back() : input_nodes.front();
front_is_scale = output_nodes.front()->arg()->is_weight;
Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
quant_nodes.push_back(&cur_node);
}
}
for (auto quant_node : quant_nodes) {
// find input nodes and output nodes
std::list<Node*> input_nodes = quant_node->inlinks;
std::list<Node*> output_nodes = quant_node->outlinks;
CHECK_EQ(input_nodes.size(), 2);
CHECK_EQ(output_nodes.size(), 2);
// relink nodes and save value to quantized_node
int bit_length = cur_node.stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = cur_node.stmt()->op()->scope();
auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
bool front_is_scale = input_nodes.front()->arg()->is_weight;
Node* input_scale_node =
front_is_scale ? input_nodes.front() : input_nodes.back();
Node* input_act_node =
front_is_scale ? input_nodes.back() : input_nodes.front();
front_is_scale = output_nodes.front()->arg()->is_weight;
Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
for (auto* quantized_node_ptr : output_act_node->outlinks) {
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>(
"bit_length", bit_length);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
"input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
RemoveDirectedLink(output_act_node, quantized_node_ptr);
}
// relink nodes and save value to quantized_node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, &cur_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph.get(), nodes2rm);
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node_ptr : outlinks) {
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>("bit_length",
bit_length);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
"input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
RemoveDirectedLink(output_act_node, quantized_node_ptr);
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
// fuse quantized node and dequant node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册