diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 2814020b94f2ce2e0487a09886d8d99656706c06..5dcdc751c820ea573a5b3e2fbebfdc41c9b3c23b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1529,39 +1529,24 @@ PDNode *patterns::RequantOp::operator()() { return any_op; } -PDNode *patterns::ConvDequant::operator()() { - // Create Operators - auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); - auto dequant_op = - pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); - - auto conv_out = pattern->NewNode(conv_out_repr()) - ->assert_is_op_output("conv2d", "Output"); - auto dequant_out = pattern->NewNode(dequant_out_repr()) - ->AsOutput() - ->assert_is_op_output("dequantize", "Output"); - - conv_op->LinksTo({conv_out}); - dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out}); - - return dequant_out; -} - -PDNode *patterns::FcDequant::operator()() { - // Create Operators - auto fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc"); +PDNode *patterns::OpDequant::operator()() { + auto any_op = pattern->NewNode(any_op_repr()) + ->assert_is_op() + ->assert_more([&](Node *node) { + return (node->Op()->Type() == "matmul" || + node->Op()->Type() == "conv2d" || + node->Op()->Type() == "fc"); + }); + auto dequant_in = pattern->NewNode(dequant_in_repr()) + ->assert_is_op_input("dequantize", "Input"); auto dequant_op = pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); - - auto fc_out = - pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out"); auto dequant_out = pattern->NewNode(dequant_out_repr()) ->AsOutput() ->assert_is_op_output("dequantize", "Output"); - fc_op->LinksTo({fc_out}); - dequant_op->LinksFrom({fc_out}).LinksTo({dequant_out}); - + any_op->LinksTo({dequant_in}); + dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); return dequant_out; } @@ -1584,23 +1569,6 @@ PDNode *patterns::DequantScale::operator()() { return scale_out; } -PDNode *patterns::MatmulDequant::operator()() { - auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); - auto dequant_op = - pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); - - auto matmul_out = pattern->NewNode(matmul_out_repr()) - ->AsOutput() - ->assert_is_op_output("matmul", "Out"); - auto dequant_out = pattern->NewNode(dequant_out_repr()) - ->AsOutput() - ->assert_is_op_output("dequantize", "Output"); - - matmul_op->LinksTo({matmul_out}); - dequant_op->LinksFrom({matmul_out}).LinksTo({dequant_out}); - return dequant_out; -} - PDNode *patterns::ScaleMatmul::operator()() { auto scale_in = pattern->NewNode(scale_in_repr()) ->AsInput() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 58248aec709d5e24d804f6bc19aa36ca392aad7f..91a086d090a7434ff071fe345d82f3ccb5d1b36e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -929,33 +929,18 @@ struct RequantOp : public PatternBase { PATTERN_DECL_NODE(requant_out); }; -// Conv + Dequant +// Op + Dequant // named nodes: -// conv_op, conv_out +// any_op, dequant_in // dequant_op, dequant_out -struct ConvDequant : public PatternBase { - ConvDequant(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "conv_dequant") {} - - PDNode* operator()(); - - PATTERN_DECL_NODE(conv_op); - PATTERN_DECL_NODE(conv_out); - - PATTERN_DECL_NODE(dequant_op); - PATTERN_DECL_NODE(dequant_out); -}; - -// Fc + Dequant -struct FcDequant : public PatternBase { - FcDequant(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "fc_dequant") {} +struct OpDequant : public PatternBase { + OpDequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "op_dequant") {} PDNode* operator()(); - PATTERN_DECL_NODE(fc_op); - PATTERN_DECL_NODE(fc_out); - + PATTERN_DECL_NODE(any_op); + PATTERN_DECL_NODE(dequant_in); PATTERN_DECL_NODE(dequant_op); PATTERN_DECL_NODE(dequant_out); }; @@ -974,20 +959,6 @@ struct DequantScale : public PatternBase { PATTERN_DECL_NODE(scale_out); }; -// Matmul + Dequantize -struct MatmulDequant : public PatternBase { - MatmulDequant(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "matmul_dequant") {} - - PDNode* operator()(); - - PATTERN_DECL_NODE(matmul_op); - PATTERN_DECL_NODE(matmul_out); - - PATTERN_DECL_NODE(dequant_op); - PATTERN_DECL_NODE(dequant_out); -}; - // Scale + Matmul struct ScaleMatmul : public PatternBase { ScaleMatmul(PDPattern* pattern, const std::string& name_scope) diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index ed57e329443ea6dcf75d70a82361955928dbe559..6283b65be9667d5069e47de87c42bf507bc459c5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -223,71 +223,44 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const { found_requant_squash_count); } -void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const { - GraphPatternDetector gpd; - patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(), - "conv_dequant"}; - conv_dequant_pattern(); - - int found_conv_dequant_squash_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "squash conv-dequant ops pair"; - - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern); - - // if conv2d has one output - // and there is no fuse residual connection - // because residual fusion does not support force output with fp32 - if (conv_out->outputs.size() == 1 && - !(conv_op->Op()->GetAttrIfExists("fuse_residual_connection"))) { - conv_op->Op()->SetAttr("force_fp32_output", true); - conv_op->Op()->SetOutput("Output", - std::vector({dequant_out->Name()})); - IR_NODE_LINK_TO(conv_op, dequant_out); - GraphSafeRemoveNodes(graph, {conv_out, dequant_op}); - found_conv_dequant_squash_count++; - } - }; - gpd(graph, handler); - AddStatis(found_conv_dequant_squash_count); - PrettyLogDetail("--- squashed %d dequant with convs", - found_conv_dequant_squash_count); -} - -// squash fc with dequant -void CPUQuantizeSquashPass::FcDequantSquash(Graph* graph) const { +// squash dequant with previous op if that op has force_fp32_output attr +// conv2d, fc, matmul +void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { GraphPatternDetector gpd; - patterns::FcDequant fc_dequant_pattern{gpd.mutable_pattern(), "fc_dequant"}; - fc_dequant_pattern(); + patterns::OpDequant op_dequant_pattern{gpd.mutable_pattern(), "op_dequant"}; + op_dequant_pattern(); - int found_fc_dequant_squash_count = 0; + int found_op_dequant_squash_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "squash fc-dequant ops pair"; - - GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, fc_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, fc_dequant_pattern); - - // if fc has force_fp32_output attribute - if (fc_out->outputs.size() == 1) { - fc_op->Op()->SetAttr("force_fp32_output", true); - fc_op->Op()->SetOutput("Out", - std::vector({dequant_out->Name()})); - IR_NODE_LINK_TO(fc_op, dequant_out); - GraphSafeRemoveNodes(graph, {fc_out, dequant_op}); - found_fc_dequant_squash_count++; + VLOG(4) << "squash op-dequant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, op_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern); + + if (dequant_in->outputs.size() == 1) { + auto output_name = "Out"; + if (any_op->Op()->Type() == "conv2d") { + // do not squash if fuse residual connection is true + // because residual fusion does not support force output with fp32 + if (any_op->Op()->GetAttrIfExists("fuse_residual_connection")) + return; + output_name = "Output"; + } + any_op->Op()->SetAttr("force_fp32_output", true); + any_op->Op()->SetOutput(output_name, + std::vector({dequant_out->Name()})); + IR_NODE_LINK_TO(any_op, dequant_out); + GraphSafeRemoveNodes(graph, {dequant_in, dequant_op}); + found_op_dequant_squash_count++; } }; gpd(graph, handler); - AddStatis(found_fc_dequant_squash_count); - PrettyLogDetail("--- squashed %d dequant with fcs", - found_fc_dequant_squash_count); + AddStatis(found_op_dequant_squash_count); + PrettyLogDetail("--- squashed %d dequant with ops", + found_op_dequant_squash_count); } void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { @@ -389,38 +362,6 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { found_dequant_scale_squash_count); } -// squash dequant with dequant -void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const { - GraphPatternDetector gpd; - patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(), - "matmul_dequant"}; - matmul_dequant_pattern(); - - int found_matmul_dequant_squash_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "squash matmul-dequant ops pair"; - - GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern); - - if (matmul_out->outputs.size() == 1) { - matmul_op->Op()->SetAttr("force_fp32_output", true); - matmul_op->Op()->SetOutput( - "Out", std::vector({dequant_out->Name()})); - IR_NODE_LINK_TO(matmul_op, dequant_out); - GraphSafeRemoveNodes(graph, {matmul_out, dequant_op}); - found_matmul_dequant_squash_count++; - } - }; - gpd(graph, handler); - AddStatis(found_matmul_dequant_squash_count); - PrettyLogDetail("--- squashed %d dequant with matmul", - found_matmul_dequant_squash_count); -} - void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, @@ -433,11 +374,9 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { DequantQuantSquash(graph, &nodes_keep_counter); OpRequantSquash(graph); RequantOpSquash(graph); - ConvDequantSquash(graph); - FcDequantSquash(graph); + OpDequantSquash(graph); MultipleQuantizeSquash(graph); DequantScaleSquash(graph); - MatmulDequantSquash(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index 8ce5c1858559c6083df17d66630931818858d6f8..98a518e4e532bb250459448e864a4fb89d55686f 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -61,29 +61,19 @@ class CPUQuantizeSquashPass : public FusePassBase { void RequantOpSquash(Graph* graph) const; /* - * Squash conv2d with dequant when dequant is the only op after conv2d - */ - void ConvDequantSquash(Graph* graph) const; - - /* - * Squash fc with dequant when dequant is the next op after fc - */ - void FcDequantSquash(Graph* graph) const; + * Squash dequant if the previous operator has force_fp32_output attribute + */ + void OpDequantSquash(Graph* graph) const; /* - * Squash quantize if several quatize ops have the same scale - */ + * Squash quantize if several quatize ops have the same scale + */ void MultipleQuantizeSquash(Graph* graph) const; /* - * Squash scale if dequantize is before scale - */ - void DequantScaleSquash(Graph* graph) const; - - /* - * Squash dequantize if it is after matmul + * Squash scale if dequantize is before scale */ - void MatmulDequantSquash(Graph* graph) const; + void DequantScaleSquash(Graph* graph) const; const std::string name_scope_{"squash"}; };