diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index fc0a30bed0efb4a12b5bab0e597b3093f8cb074f..ee7a2a722331e062a2ac55d213043fa5ec489f39 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -31,8 +31,7 @@ namespace ir { GET_IR_NODE(quantize_linear_op); \ GET_IR_NODE(quantize_linear_op_out); \ GET_IR_NODE(dequantize_linear_op); \ - GET_IR_NODE(dequantize_linear_op_out); \ - GET_IR_NODE(any_op2); + GET_IR_NODE(dequantize_linear_op_out); DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { AddOpCompat(OpCompat("quantize_linear")) @@ -127,21 +126,24 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { const float* input_scale_data = input_scale_tensor.data(); float input_scale = input_scale_data[0] / range; - auto* any_op2_desc = any_op2->Op(); - any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), - input_scale); + int nums_any_ops = dequantize_linear_op_out->outputs.size(); + for (int i = 0; i < nums_any_ops; ++i) { + auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op(); + any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), + input_scale); + // link x to any_op2 + any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), + quantize_linear_op_x->Var()->Name()); + any_op_desc->Flush(); + IR_NODE_LINK_TO(quantize_linear_op_x, + dequantize_linear_op_out->outputs[i]); + } nodes2rm.insert(quantize_linear_op_scale); nodes2rm.insert(quantize_linear_op); nodes2rm.insert(quantize_linear_op_out); nodes2rm.insert(dequantize_linear_op); nodes2rm.insert(dequantize_linear_op_out); - - // link x to any_op2 - any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), - quantize_linear_op_x->Var()->Name()); - any_op2_desc->Flush(); - IR_NODE_LINK_TO(quantize_linear_op_x, any_op2); GraphSafeRemoveNodes(graph, nodes2rm); found_count++; }; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 85b3bdb874d4f5ebf15d10d9998e18ed90f6945d..5f8dcf9b7e5d892c103ac7a5ffb4a43dc45d00d9 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2992,16 +2992,14 @@ void patterns::DeleteQuantDequantLinearOpPattern::operator()() { auto dequantize_linear_op_out = pattern->NewNode(dequantize_linear_op_out_repr()) ->AsIntermediate() - ->assert_is_op_output("dequantize_linear", "Y"); - - auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + ->assert_is_op_output("dequantize_linear", "Y") + ->AsOutput(); quantize_linear_op ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale}) .LinksTo({quantize_linear_op_out}); dequantize_linear_op->LinksFrom({quantize_linear_op_out}) .LinksTo({dequantize_linear_op_out}); - any_op2->LinksFrom({dequantize_linear_op_out}); } PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f0f7282683b710519256b0d10c94628dcf6d676d..9e2eb21b7f42dbc43caeb201e094d0ef70b08d80 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1709,7 +1709,6 @@ struct DeleteQuantDequantLinearOpPattern : public PatternBase { // PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node. // Todo: Wangzheee PATTERN_DECL_NODE(dequantize_linear_op_out); - PATTERN_DECL_NODE(any_op2); }; // Reshape + Transpose + Matmul