From 17f2c0899f56ca97961fe0567b4ee160670ea5b4 Mon Sep 17 00:00:00 2001
From: "joanna.wozna.intel" <joanna.wozna@intel.com>
Date: Thu, 6 Feb 2020 04:40:20 +0100
Subject: [PATCH] Add dequant-scale squash (#22409)

* Add dequant scale squash

test=develop

* Correct dequant-scale squash test

test=develop
---
 .../framework/ir/graph_pattern_detector.cc    | 19 +++++
 .../framework/ir/graph_pattern_detector.h     | 14 ++++
 .../ir/mkldnn/cpu_quantize_squash_pass.cc     | 44 +++++++++++
 .../ir/mkldnn/cpu_quantize_squash_pass.h      |  5 ++
 .../mkldnn/cpu_quantize_squash_pass_tester.cc | 75 ++++++++++++++++---
 5 files changed, 145 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 0b4b18c94b4..919364541e4 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1522,6 +1522,25 @@ PDNode *patterns::FcDequant::operator()() {
   return dequant_out;
 }
 
+PDNode *patterns::DequantScale::operator()() {
+  // Create Operators
+  auto dequant_op =
+      pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
+  auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
+
+  auto dequant_out = pattern->NewNode(dequant_out_repr())
+                         ->AsOutput()
+                         ->assert_is_op_output("dequantize", "Output");
+  auto scale_out = pattern->NewNode(scale_out_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("scale", "Out");
+
+  dequant_op->LinksTo({dequant_out});
+  scale_op->LinksFrom({dequant_out}).LinksTo({scale_out});
+
+  return scale_out;
+}
+
 PDNode *patterns::PriorBox::operator()() {
   auto prior_box_op =
       pattern->NewNode(prior_box_op_repr())->assert_is_op("prior_box");
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index db58c9e8fdd..dcdf4318c88 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -929,6 +929,20 @@ struct FcDequant : public PatternBase {
   PATTERN_DECL_NODE(dequant_out);
 };
 
+// Dequantize + Scale
+struct DequantScale : public PatternBase {
+  DequantScale(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "dequant_scale") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(dequant_op);
+  PATTERN_DECL_NODE(dequant_out);
+
+  PATTERN_DECL_NODE(scale_op);
+  PATTERN_DECL_NODE(scale_out);
+};
+
 // PriorBox operator
 // operator: prior_box_op
 // inputs: prior_box_input, prior_box_image
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
index eff9b294f70..66556c7cc86 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -284,6 +284,49 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
   PrettyLogDetail("---    squashed %d quantize op", removed_quantize);
 }
 
+// squash scale with dequant
+void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::DequantScale dequant_scale_pattern{gpd.mutable_pattern(),
+                                               "dequant_scale"};
+  dequant_scale_pattern();
+
+  int found_dequant_scale_squash_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "squash dequant-scale ops pair";
+
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, dequant_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, dequant_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, dequant_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, dequant_scale_pattern);
+
+    if (dequant_out->outputs.size() == 1 &&
+        scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
+      auto dequant_scale = dequant_op->Op()->GetAttrIfExists<float>("Scale");
+      auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
+
+      PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
+                        platform::errors::InvalidArgument(
+                            "Dequantize scale should have positive value"));
+      PADDLE_ENFORCE_GT(scale_scale, 0.0f,
+                        platform::errors::InvalidArgument(
+                            "Scale of scale op should have positive value"));
+
+      dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
+      dequant_op->Op()->SetOutput(
+          "Output", std::vector<std::string>({scale_out->Name()}));
+      IR_NODE_LINK_TO(dequant_op, scale_out);
+      GraphSafeRemoveNodes(graph, {dequant_out, scale_op});
+      found_dequant_scale_squash_count++;
+    }
+  };
+  gpd(graph, handler);
+  AddStatis(found_dequant_scale_squash_count);
+  PrettyLogDetail("---    squashed %d scale with dequant",
+                  found_dequant_scale_squash_count);
+}
+
 void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph,
@@ -298,6 +341,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
   ConvDequantSquash(graph);
   FcDequantSquash(graph);
   MultipleQuantizeSquash(graph);
+  DequantScaleSquash(graph);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
index af8a66c929b..41c5323ba5c 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
@@ -70,6 +70,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
   */
   void MultipleQuantizeSquash(Graph* graph) const;
 
+  /*
+  *  Squash scale if dequantize is before scale
+  */
+  void DequantScaleSquash(Graph* graph) const;
+
   const std::string name_scope_{"squash"};
 };
 
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
index 5a364aab1c5..1ce7fc9a72b 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
@@ -24,7 +24,7 @@ namespace ir {
 void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs, bool use_mkldnn,
-           float scale = 0) {
+           float scale = 0, float bias = 0.0) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   op->SetAttr("use_mkldnn", use_mkldnn);
@@ -59,6 +59,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
                           inputs.size()));
     op->SetInput("W", {inputs[1]});
     op->SetOutput("Out", outputs);
+  } else if (type == "scale") {
+    op->SetInput("X", {inputs[0]});
+    op->SetOutput("Out", {outputs[0]});
+    op->SetAttr("scale", scale);
+    op->SetAttr("bias", bias);
   }
 }
 
@@ -252,6 +257,21 @@ ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
   return prog;
 }
 
+// a->Dequant->b
+// b->Scale->c
+ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
+                                         float scale_scale, float bias) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "dequantize", "Dequant", {"a"}, {"b"}, use_mkldnn,
+        dequant_scale);
+  SetOp(&prog, "scale", "Scale", {"b"}, {"c"}, use_mkldnn, scale_scale, bias);
+
+  return prog;
+}
+
 void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                       const char* var_name) {
   auto x = scope->Var(var_name);
@@ -289,17 +309,17 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
 }
 
 // check op->scale_out
-void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
-                       float scale) {
+void EqualScaleTest(const ProgramDesc& prog, const std::string& op_name,
+                    const std::string& scale_name, float scale) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   PrepareGraph(&graph, prog);
   RegisterPass(&graph);
 
   for (auto* node : graph->Nodes()) {
     if (node->IsOp() &&
-        boost::get<std::string>(node->Op()->GetAttr("name")) == name) {
-      float scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
-      EXPECT_EQ(scale_out, scale);
+        boost::get<std::string>(node->Op()->GetAttr("name")) == op_name) {
+      float op_scale = boost::get<float>(node->Op()->GetAttr(scale_name));
+      EXPECT_EQ(op_scale, scale);
     }
   }
 }
@@ -368,9 +388,9 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
       BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
       remove_nodes);
 
-  EqualScaleOutTest(
+  EqualScaleTest(
       BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
-      "Conv1", scale2);
+      "Conv1", "Scale_out", scale2);
 }
 
 //  a->Conv1->b->Requant->c
@@ -388,12 +408,12 @@ TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
       remove_nodes);
 
   // check equal scale conv->scale_out and requant->scale_out
-  EqualScaleOutTest(
+  EqualScaleTest(
       BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv1", scale);
-  EqualScaleOutTest(
+      "Conv1", "Scale_out", scale);
+  EqualScaleTest(
       BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
-      "Conv2", scale);
+      "Conv2", "Scale_out", scale);
 }
 
 // from
@@ -544,6 +564,37 @@ TEST(CpuQuantizeSquashPass, quatize_with_different_scale) {
       remove_nodes);
 }
 
+// if scale has no bias
+TEST(CpuQuantizeSquashPass, dequantize_scale_with_no_bias) {
+  auto dequant_scale = 1.2345f;
+  auto scale_scale = 1.5432f;
+  auto bias = 0.0f;
+  auto use_mkldnn = true;
+  // remove: dequant out, scale op
+  auto remove_nodes = 2;
+  CountNodeTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                             scale_scale, bias),
+                remove_nodes);
+  EqualScaleTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                              scale_scale, bias),
+                 "Dequant", "Scale", dequant_scale / scale_scale);
+}
+
+// if scale has bias
+TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
+  auto dequant_scale = 1.2345f;
+  auto scale_scale = 1.5432f;
+  auto bias = 1.0f;
+  auto use_mkldnn = true;
+  // nothing change
+  auto remove_nodes = 0;
+  CountNodeTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                             scale_scale, bias),
+                remove_nodes);
+  EqualScaleTest(BuildDequantScaleProgramDesc(use_mkldnn, dequant_scale,
+                                              scale_scale, bias),
+                 "Dequant", "Scale", dequant_scale);
+}
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
-- 
GitLab