From 2fb65e44f3cf430ebf3e714f7a9f8d0c80f6c66d Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 16 Aug 2022 14:04:15 +0800 Subject: [PATCH] fix new quant (#45155) --- .../ir/delete_quant_dequant_linear_op_pass.cc | 24 ++++++++++--------- .../framework/ir/graph_pattern_detector.cc | 6 ++--- .../framework/ir/graph_pattern_detector.h | 1 - 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index fc0a30bed0e..ee7a2a72233 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -31,8 +31,7 @@ namespace ir { GET_IR_NODE(quantize_linear_op); \ GET_IR_NODE(quantize_linear_op_out); \ GET_IR_NODE(dequantize_linear_op); \ - GET_IR_NODE(dequantize_linear_op_out); \ - GET_IR_NODE(any_op2); + GET_IR_NODE(dequantize_linear_op_out); DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { AddOpCompat(OpCompat("quantize_linear")) @@ -127,21 +126,24 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { const float* input_scale_data = input_scale_tensor.data(); float input_scale = input_scale_data[0] / range; - auto* any_op2_desc = any_op2->Op(); - any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), - input_scale); + int nums_any_ops = dequantize_linear_op_out->outputs.size(); + for (int i = 0; i < nums_any_ops; ++i) { + auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op(); + any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), + input_scale); + // link x to any_op2 + any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), + quantize_linear_op_x->Var()->Name()); + any_op_desc->Flush(); + IR_NODE_LINK_TO(quantize_linear_op_x, + dequantize_linear_op_out->outputs[i]); + } nodes2rm.insert(quantize_linear_op_scale); nodes2rm.insert(quantize_linear_op); nodes2rm.insert(quantize_linear_op_out); nodes2rm.insert(dequantize_linear_op); nodes2rm.insert(dequantize_linear_op_out); - - // link x to any_op2 - any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), - quantize_linear_op_x->Var()->Name()); - any_op2_desc->Flush(); - IR_NODE_LINK_TO(quantize_linear_op_x, any_op2); GraphSafeRemoveNodes(graph, nodes2rm); found_count++; }; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 85b3bdb874d..5f8dcf9b7e5 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2992,16 +2992,14 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() { auto dequantize_linear_op_out = pattern->NewNode(dequantize_linear_op_out_repr()) ->AsIntermediate() - ->assert_is_op_output("dequantize_linear", "Y"); - - auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + ->assert_is_op_output("dequantize_linear", "Y") + ->AsOutput(); quantize_linear_op ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale}) .LinksTo({quantize_linear_op_out}); dequantize_linear_op->LinksFrom({quantize_linear_op_out}) .LinksTo({dequantize_linear_op_out}); - any_op2->LinksFrom({dequantize_linear_op_out}); } PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f0f7282683b..9e2eb21b7f4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1709,7 +1709,6 @@ struct DeleteQuantDequantLinearOpPattern : public PatternBase { // PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node. // Todo: Wangzheee PATTERN_DECL_NODE(dequantize_linear_op_out); - PATTERN_DECL_NODE(any_op2); }; // Reshape + Transpose + Matmul -- GitLab