diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index e82de1b13f3f67495ef2c39a9b5eec42a44833cf..32aa6bf7ceb5ee5539cec74a5125d43b07c52b19 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1490,20 +1490,22 @@ PDNode *patterns::ConvConcatReLU::operator()() { return relu_out; } -PDNode *patterns::ConvRequant::operator()() { - // Create Operators - auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); +PDNode *patterns::OpRequant::operator()() { + auto any_op = pattern->NewNode(any_op_repr()) + ->assert_is_op() + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("Scale_out") ? true : false; + }); + auto requant_in = pattern->NewNode(requant_in_repr()) + ->assert_is_op_input("requantize", "Input"); auto requant_op = pattern->NewNode(requant_op_repr())->assert_is_op("requantize"); - auto conv_out = pattern->NewNode(conv_out_repr()) - ->assert_is_op_output("conv2d", "Output"); auto requant_out = pattern->NewNode(requant_out_repr()) ->AsOutput() ->assert_is_op_output("requantize", "Output"); - conv_op->LinksTo({conv_out}); - requant_op->LinksFrom({conv_out}).LinksTo({requant_out}); - + any_op->LinksTo({requant_in}); + requant_op->LinksFrom({requant_in}).LinksTo({requant_out}); return requant_out; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index e7e912e54ab948ced59465e49a8386e05d479ad0..5444c143bf3874098d9209c83e13f8301faf9735 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -897,19 +897,18 @@ struct ConvConcatReLU : public PatternBase { PATTERN_DECL_NODE(relu_out); }; -// Conv + Requant +// Op + Requant // named nodes: -// conv_op, conv_out +// any_op, any_out // requant_op, requant_out -struct ConvRequant : public PatternBase { - ConvRequant(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "conv_requant") {} +struct OpRequant : public PatternBase { + OpRequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "op_requant") {} PDNode* operator()(); - PATTERN_DECL_NODE(conv_op); - PATTERN_DECL_NODE(conv_out); - + PATTERN_DECL_NODE(any_op); + PATTERN_DECL_NODE(requant_in); PATTERN_DECL_NODE(requant_op); PATTERN_DECL_NODE(requant_out); }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index f8f1a2ddd5b27da649e838efce2de0449dc92b99..92396306934a682f4316e6fb6ed4ca4ba07ae0e7 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -126,38 +126,47 @@ void CPUQuantizeSquashPass::DequantQuantSquash( found_dequant_quant_count); } -void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const { +// op+requant squash if op has Scale_out attr +// conv2d and fc +void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const { GraphPatternDetector gpd; - patterns::ConvRequant conv_requant_pattern{gpd.mutable_pattern(), - "conv_requant"}; - conv_requant_pattern(); + patterns::OpRequant op_requant_pattern{gpd.mutable_pattern(), "op_requant"}; + op_requant_pattern(); int found_requant_squash_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "squash conv-requantize ops pair"; + VLOG(4) << "squash op-requantize ops pair"; - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_requant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_requant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, conv_requant_pattern); - GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, conv_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, op_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_in, requant_in, op_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_op, requant_op, op_requant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(requant_out, requant_out, op_requant_pattern); + + if (requant_in->outputs.size() == 1) { + std::string any_op_output_name; + for (auto name : any_op->Op()->OutputNames()) + for (auto output_name : any_op->Op()->Output(name)) + if (output_name == requant_in->Name()) any_op_output_name = name; + + PADDLE_ENFORCE_NE( + any_op_output_name.empty(), true, + platform::errors::NotFound("Operator before requantize operator " + "should has requantize input as output")); - // if conv2d has one output squash - if (conv_out->outputs.size() == 1) { float requant_scale_out = boost::get(requant_op->Op()->GetAttr("Scale_out")); - conv_op->Op()->SetAttr("Scale_out", requant_scale_out); - conv_op->Op()->SetOutput("Output", - std::vector({requant_out->Name()})); - IR_NODE_LINK_TO(conv_op, requant_out); - GraphSafeRemoveNodes(graph, {conv_out, requant_op}); - + any_op->Op()->SetAttr("Scale_out", requant_scale_out); + any_op->Op()->SetOutput(any_op_output_name, + std::vector({requant_out->Name()})); + IR_NODE_LINK_TO(any_op, requant_out); + GraphSafeRemoveNodes(graph, {requant_in, requant_op}); found_requant_squash_count++; } }; gpd(graph, handler); AddStatis(found_requant_squash_count); - PrettyLogDetail("--- squashed %d requantize with convs", + PrettyLogDetail("--- squashed %d requantize ops", found_requant_squash_count); } @@ -369,7 +378,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { std::unordered_map nodes_keep_counter; FindNodesToKeep(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter); - ConvRequantSquash(graph); + OpRequantSquash(graph); ConvDequantSquash(graph); FcDequantSquash(graph); MultipleQuantizeSquash(graph); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index 475c0591f3656efd53074ed037138fc745f86b96..be770fb0b2b569adefe6f3062dc7a907eaeac716 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -53,7 +53,7 @@ class CPUQuantizeSquashPass : public FusePassBase { /* * Squash requantize op into conv with scale_out like requantize scale_out */ - void ConvRequantSquash(Graph* graph) const; + void OpRequantSquash(Graph* graph) const; /* * Squash conv2d with dequant when dequant is the only op after conv2d diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 6adf1fcaa51cc46d957779c7231e57b6a7cd495b..163497d80c42c16d91b0100b036051047dc51306 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -59,6 +59,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, inputs.size())); op->SetInput("W", {inputs[1]}); op->SetOutput("Out", outputs); + op->SetAttr("Scale_out", scale); } else if (type == "scale") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); @@ -68,6 +69,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("X", {inputs[0]}); op->SetInput("Y", {inputs[1]}); op->SetOutput("Out", {outputs[0]}); + op->SetAttr("Scale_out", scale); } } @@ -96,7 +98,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out, } static const std::initializer_list variable_names{ - "a", "b", "c", "d", "e", "f", "g", "h", "x", "y"}; + "a", "b", "c", "d", "e", "f", "g", "h", "x", "y", "w1"}; // a->Conv1->b // b->Dequant(scale1)->c @@ -125,23 +127,30 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out, return prog; } -// a->Conv1->b->Requant(scale1)->c -// d->Conv2->e->Requant(scale2)->f -// {c,f}->Concat -ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out, - float scale1, float scale2) { +// a->Conv->b->Requant(scale1)->c +// d->Fc->e->Requant(scale2)->f +// {x,y}->Matmul->g->Requant(scale3)->h +// {c,f,h}->Concat +ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale, + float fc_scale, float matmul_scale, + float requant_scale1, + float requant_scale2, + float requant_scale3) { ProgramDesc prog; for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); - SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1); - - SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out); - SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2); - - SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, conv_scale); + SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, + requant_scale1); + SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, fc_scale); + SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, + requant_scale2); + SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, matmul_scale); + SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn, + requant_scale3); + SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn); return prog; } @@ -412,27 +421,28 @@ TEST(CpuQuantizeSquashPass, unequal_scales) { "Conv1", "Scale_out", scale2); } -// a->Conv1->b->Requant->c -// d->Conv2->e->Requant->f -// {c,f}->Concat -TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) { - // Delete both requantize op - auto scale_out = 1.0f; - auto scale = 1.2345f; +// a->Conv->b->Requant->c +// d->Fc->e->Requant->f +// {x,y}->Matmul->g->Requant->h +// {c,f,h}->Concat +TEST(CpuQuantizeSquashPass, op_requantize_squash) { + // Delete all requantize op + auto conv_scale = 0.234f; + auto fc_scale = 1.234f; + auto matmul_scale = 2.234f; + auto requant_scale1 = 2.234f; + auto requant_scale2 = 3.234f; + auto requant_scale3 = 4.234f; auto use_mkldnn = true; - // Remove 4 nodes: b, Requant1, e, Requant2 - auto remove_nodes = 4; - CountNodeTest( - BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), - remove_nodes); - - // check equal scale conv->scale_out and requant->scale_out - EqualScaleTest( - BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), - "Conv1", "Scale_out", scale); - EqualScaleTest( - BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale), - "Conv2", "Scale_out", scale); + // Remove 4 nodes: b, Requant1, e, Requant2, g, Requant3 + auto remove_nodes = 6; + auto program_desc = + BuildOpRequantProgramDesc(use_mkldnn, conv_scale, fc_scale, matmul_scale, + requant_scale1, requant_scale2, requant_scale3); + CountNodeTest(program_desc, remove_nodes); + EqualScaleTest(program_desc, "Conv", "Scale_out", requant_scale1); + EqualScaleTest(program_desc, "Fc", "Scale_out", requant_scale2); + EqualScaleTest(program_desc, "Matmul", "Scale_out", requant_scale3); } // from