未验证 提交 2fb65e44 编写于 作者: W Wangzheee 提交者: GitHub

fix new quant (#45155)

上级 ab583173
...@@ -31,8 +31,7 @@ namespace ir { ...@@ -31,8 +31,7 @@ namespace ir {
GET_IR_NODE(quantize_linear_op); \ GET_IR_NODE(quantize_linear_op); \
GET_IR_NODE(quantize_linear_op_out); \ GET_IR_NODE(quantize_linear_op_out); \
GET_IR_NODE(dequantize_linear_op); \ GET_IR_NODE(dequantize_linear_op); \
GET_IR_NODE(dequantize_linear_op_out); \ GET_IR_NODE(dequantize_linear_op_out);
GET_IR_NODE(any_op2);
DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
AddOpCompat(OpCompat("quantize_linear")) AddOpCompat(OpCompat("quantize_linear"))
...@@ -127,21 +126,24 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -127,21 +126,24 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const float* input_scale_data = input_scale_tensor.data<float>(); const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / range; float input_scale = input_scale_data[0] / range;
auto* any_op2_desc = any_op2->Op(); int nums_any_ops = dequantize_linear_op_out->outputs.size();
any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), for (int i = 0; i < nums_any_ops; ++i) {
input_scale); auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
// link x to any_op2
any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
any_op_desc->Flush();
IR_NODE_LINK_TO(quantize_linear_op_x,
dequantize_linear_op_out->outputs[i]);
}
nodes2rm.insert(quantize_linear_op_scale); nodes2rm.insert(quantize_linear_op_scale);
nodes2rm.insert(quantize_linear_op); nodes2rm.insert(quantize_linear_op);
nodes2rm.insert(quantize_linear_op_out); nodes2rm.insert(quantize_linear_op_out);
nodes2rm.insert(dequantize_linear_op); nodes2rm.insert(dequantize_linear_op);
nodes2rm.insert(dequantize_linear_op_out); nodes2rm.insert(dequantize_linear_op_out);
// link x to any_op2
any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(quantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm); GraphSafeRemoveNodes(graph, nodes2rm);
found_count++; found_count++;
}; };
......
...@@ -2992,16 +2992,14 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() { ...@@ -2992,16 +2992,14 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
auto dequantize_linear_op_out = auto dequantize_linear_op_out =
pattern->NewNode(dequantize_linear_op_out_repr()) pattern->NewNode(dequantize_linear_op_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y"); ->assert_is_op_output("dequantize_linear", "Y")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quantize_linear_op quantize_linear_op
->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale}) ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale})
.LinksTo({quantize_linear_op_out}); .LinksTo({quantize_linear_op_out});
dequantize_linear_op->LinksFrom({quantize_linear_op_out}) dequantize_linear_op->LinksFrom({quantize_linear_op_out})
.LinksTo({dequantize_linear_op_out}); .LinksTo({dequantize_linear_op_out});
any_op2->LinksFrom({dequantize_linear_op_out});
} }
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
......
...@@ -1709,7 +1709,6 @@ struct DeleteQuantDequantLinearOpPattern : public PatternBase { ...@@ -1709,7 +1709,6 @@ struct DeleteQuantDequantLinearOpPattern : public PatternBase {
// PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node. // PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node.
// Todo: Wangzheee // Todo: Wangzheee
PATTERN_DECL_NODE(dequantize_linear_op_out); PATTERN_DECL_NODE(dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
}; };
// Reshape + Transpose + Matmul // Reshape + Transpose + Matmul
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册