From 3cb5623dad945f35d18a36f54d0a913db0920153 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Wed, 8 Apr 2020 16:10:19 +0200 Subject: [PATCH] Add matmul dequant squash (#23505) test=develop --- .../framework/ir/graph_pattern_detector.cc | 17 ++++++++++ .../framework/ir/graph_pattern_detector.h | 14 ++++++++ .../ir/mkldnn/cpu_quantize_squash_pass.cc | 33 +++++++++++++++++++ .../ir/mkldnn/cpu_quantize_squash_pass.h | 5 +++ .../mkldnn/cpu_quantize_squash_pass_tester.cc | 32 +++++++++++++++++- 5 files changed, 100 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 8822c1a4c9..e82de1b13f 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1562,6 +1562,23 @@ PDNode *patterns::DequantScale::operator()() { return scale_out; } +PDNode *patterns::MatmulDequant::operator()() { + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + auto dequant_op = + pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize"); + + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul", "Out"); + auto dequant_out = pattern->NewNode(dequant_out_repr()) + ->AsOutput() + ->assert_is_op_output("dequantize", "Output"); + + matmul_op->LinksTo({matmul_out}); + dequant_op->LinksFrom({matmul_out}).LinksTo({dequant_out}); + return dequant_out; +} + PDNode *patterns::PriorBox::operator()() { auto prior_box_op = pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f2a415814d..e7e912e54a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -959,6 +959,20 @@ struct DequantScale : public PatternBase { PATTERN_DECL_NODE(scale_out); }; +// Matmul + Dequantize +struct MatmulDequant : public PatternBase { + MatmulDequant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "matmul_dequant") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); + + PATTERN_DECL_NODE(dequant_op); + PATTERN_DECL_NODE(dequant_out); +}; + // PriorBox operator // operator: prior_box_op // inputs: prior_box_input, prior_box_image diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 66556c7cc8..f8f1a2ddd5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -327,6 +327,38 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { found_dequant_scale_squash_count); } +// squash dequant with dequant +void CPUQuantizeSquashPass::MatmulDequantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::MatmulDequant matmul_dequant_pattern{gpd.mutable_pattern(), + "matmul_dequant"}; + matmul_dequant_pattern(); + + int found_matmul_dequant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash matmul-dequant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, matmul_dequant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, matmul_dequant_pattern); + + if (matmul_out->outputs.size() == 1) { + matmul_op->Op()->SetAttr("force_fp32_output", true); + matmul_op->Op()->SetOutput( + "Out", std::vector({dequant_out->Name()})); + IR_NODE_LINK_TO(matmul_op, dequant_out); + GraphSafeRemoveNodes(graph, {matmul_out, dequant_op}); + found_matmul_dequant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_matmul_dequant_squash_count); + PrettyLogDetail("--- squashed %d dequant with matmul", + found_matmul_dequant_squash_count); +} + void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, @@ -342,6 +374,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { FcDequantSquash(graph); MultipleQuantizeSquash(graph); DequantScaleSquash(graph); + MatmulDequantSquash(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index 41c5323ba5..475c0591f3 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -75,6 +75,11 @@ class CPUQuantizeSquashPass : public FusePassBase { */ void DequantScaleSquash(Graph* graph) const; + /* + * Squash dequantize if it is after matmul + */ + void MatmulDequantSquash(Graph* graph) const; + const std::string name_scope_{"squash"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 1ce7fc9a72..6adf1fcaa5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -64,6 +64,10 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetOutput("Out", {outputs[0]}); op->SetAttr("scale", scale); op->SetAttr("bias", bias); + } else if (type == "matmul") { + op->SetInput("X", {inputs[0]}); + op->SetInput("Y", {inputs[1]}); + op->SetOutput("Out", {outputs[0]}); } } @@ -92,7 +96,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out, } static const std::initializer_list variable_names{ - "a", "b", "c", "d", "e", "f", "g", "h"}; + "a", "b", "c", "d", "e", "f", "g", "h", "x", "y"}; // a->Conv1->b // b->Dequant(scale1)->c @@ -272,6 +276,21 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale, return prog; } +// {x,y}->Matmul->b +// b->Dequant->c +ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn, + float dequant_scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn); + SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, + dequant_scale); + + return prog; +} + void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, const char* var_name) { auto x = scope->Var(var_name); @@ -595,6 +614,17 @@ TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) { scale_scale, bias), "Dequant", "Scale", dequant_scale); } + +TEST(CpuQuantizeSquashPass, matmul_with_dequant) { + auto dequant_scale = 1.2345f; + auto use_mkldnn = true; + // remove: matmul_out, dequant_op + auto remove_nodes = 2; + CountNodeTest(BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), + remove_nodes); + IsForceFp32OutputTest( + BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), "matmul", true); +} } // namespace ir } // namespace framework } // namespace paddle -- GitLab