diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 78e346bbdf0ae6e50ec926b9627ad6c9966b53c5..8bec1f08b090210550ffc41c0c24ece5f697fa27 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1267,6 +1267,24 @@ PDNode *patterns::ConvRequant::operator()() { return requant_out; } +PDNode *patterns::ConvDequant::operator()() { + // Create Operators + auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); + auto dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); + + auto conv_out = pattern->NewNode(conv_out_repr()) + ->assert_is_op_output("conv2d", "Output"); + auto dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize", "Output"); + + conv_op->LinksTo({conv_out}); + dequant_op->LinksFrom({conv_out}).LinksTo({dequant_out}); + + return dequant_out; +} + PDNode *patterns::PriorBox::operator()() { auto prior_box_op = pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index dafe9a6cbf4bad4bba31886ef4da937dc37ecbd9..a99889f7cce6d08436cc9a4116786cfe37e92f2a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -793,6 +793,23 @@ struct ConvRequant : public PatternBase { PATTERN_DECL_NODE(requant_out); }; +// Conv + Dequant +// named nodes: +// conv_op, conv_out +// dequant_op, dequant_out +struct ConvDequant : public PatternBase { + ConvDequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_dequant") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(conv_op); + PATTERN_DECL_NODE(conv_out); + + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); +}; + // PriorBox operator // operator: prior_box_op // inputs: prior_box_input, prior_box_image diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 6277df14321b0b48184406ee7430618071184136..ac9ad7937a49ed249989e2bf36afba5305fdf451 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -160,6 +160,38 @@ void CPUQuantizeSquashPass::ConvRequantSquash(Graph* graph) const { found_requant_squash_count); } +void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::ConvDequant conv_dequant_pattern{gpd.mutable_pattern(), + "conv_dequant"}; + conv_dequant_pattern(); + + int found_conv_dequant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash conv-dequant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, conv_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern); + + // if conv2d has one output + if (conv_out->outputs.size() == 1) { + conv_op->Op()->SetAttr("force_fp32_output", true); + conv_op->Op()->SetOutput("Output", + std::vector({dequant_out->Name()})); + IR_NODE_LINK_TO(conv_op, dequant_out); + GraphSafeRemoveNodes(graph, {conv_out, dequant_op}); + found_conv_dequant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_conv_dequant_squash_count); + PrettyLogDetail("--- squashed %d dequant with convs", + found_conv_dequant_squash_count); +} + void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE(graph); FusePassBase::Init("cpu_quantize_squash_pass", graph); @@ -168,6 +200,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { FindNodesToKeep(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter); ConvRequantSquash(graph); + ConvDequantSquash(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index 52acdb0390a653d2d4f667488e05dfd4e3a8484c..7e9e92e3dacd7dc71ed4902133c7da00eb595faf 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -55,6 +55,11 @@ class CPUQuantizeSquashPass : public FusePassBase { */ void ConvRequantSquash(Graph* graph) const; + /* + * Squash conv2d with dequant when dequant is the only op after conv2d + */ + void ConvDequantSquash(Graph* graph) const; + const std::string name_scope_{"squash"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 2d7640097140f0a5131ef078520e816be3c76cdb..08b605a713b92e296069030a5c7c439433098b06 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out, return prog; } +// a->Conv1->b +// b->Dequant1(Scale1)->c +// c->Concat +ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out, + float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn); + return prog; +} + +// a->Conv1->b +// b->Dequant1(Scale1)->c +// b->Conv2->d +ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out, + float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out); + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale); + SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn); + return prog; +} + void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, const char* var_name) { auto x = scope->Var(var_name); @@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name, void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in, float scale_out) { std::unique_ptr graph(new ir::Graph(prog)); + PrepareGraph(&graph, prog); RegisterPass(&graph); @@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) { auto use_mkldnn = true; // Remove 4 nodes: Dequant, Quant, e, f auto remove_nodes = 4; + CountNodeTest( BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale), remove_nodes); @@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) { auto use_mkldnn = true; // Remove 4 nodes: Dequant, Quant, e, d auto remove_nodes = 4; + CountNodeTest( BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2), remove_nodes); @@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) { // Remove 3 nodes: Quant1, c, Quant2, // Insert 1 node: Requant auto remove_nodes = 2; + CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale, scale, scale2), remove_nodes); @@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass, // Remove 3 nodes: Dequant1, c, Quant // Insert 1 node: Requant auto remove_nodes = 2; + CountNodeTest( BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2), remove_nodes); @@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) { remove_nodes); } +// a->Conv1->c->Concat +TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) { + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto use_mkldnn = true; + // remove 2 nodes: Dequant1, c + auto remove_nodes = 2; + CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale), + remove_nodes); +} + +TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) { + auto scale_out = 1.0f; + auto scale = 1.2345f; + auto use_mkldnn = true; + // nothing change + auto remove_nodes = 0; + CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale), + remove_nodes); +} + } // namespace ir } // namespace framework } // namespace paddle