未验证 提交 2fb65e44 编写于 作者: W Wangzheee 提交者: GitHub

fix new quant (#45155)

上级 ab583173
......@@ -31,8 +31,7 @@ namespace ir {
GET_IR_NODE(quantize_linear_op); \
GET_IR_NODE(quantize_linear_op_out); \
GET_IR_NODE(dequantize_linear_op); \
GET_IR_NODE(dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
GET_IR_NODE(dequantize_linear_op_out);
DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
AddOpCompat(OpCompat("quantize_linear"))
......@@ -127,21 +126,24 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / range;
auto* any_op2_desc = any_op2->Op();
any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
int nums_any_ops = dequantize_linear_op_out->outputs.size();
for (int i = 0; i < nums_any_ops; ++i) {
auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
// link x directly to each op that consumed dequantize_linear_op_out
any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
any_op_desc->Flush();
IR_NODE_LINK_TO(quantize_linear_op_x,
dequantize_linear_op_out->outputs[i]);
}
nodes2rm.insert(quantize_linear_op_scale);
nodes2rm.insert(quantize_linear_op);
nodes2rm.insert(quantize_linear_op_out);
nodes2rm.insert(dequantize_linear_op);
nodes2rm.insert(dequantize_linear_op_out);
// link x to any_op2
any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(quantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
......
......@@ -2992,16 +2992,14 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
auto dequantize_linear_op_out =
pattern->NewNode(dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
->assert_is_op_output("dequantize_linear", "Y")
->AsOutput();
quantize_linear_op
->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale})
.LinksTo({quantize_linear_op_out});
dequantize_linear_op->LinksFrom({quantize_linear_op_out})
.LinksTo({dequantize_linear_op_out});
any_op2->LinksFrom({dequantize_linear_op_out});
}
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
......
......@@ -1709,7 +1709,6 @@ struct DeleteQuantDequantLinearOpPattern : public PatternBase {
// PATTERN_DECL_NODE(dequantize_linear_op_scale); // Cannot add this node.
// Todo: Wangzheee
PATTERN_DECL_NODE(dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
// Reshape + Transpose + Matmul
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册