From 930ca3f4950187131aa010f3ae88006a0fbf91ad Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Fri, 18 Jun 2021 21:26:17 +0800
Subject: [PATCH] Add op compat checks to conv_bn fuse passes (#33661)

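Register OpCompat definitions for conv2d, conv2d_transpose, batch_norm and
elementwise_add in the conv+BN fuse passes, and skip the fusion whenever the
matched subgraph or a rewritten op fails its compatibility check. Extend the
conv2d/conv2d_transpose builders in pass_tester_helper.h with the attributes
those checks inspect, and register additional extra attributes in the
batch_norm, conv2d and relu compat pbtxt files.
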
---
 .../fluid/framework/ir/conv_bn_fuse_pass.cc   | 259 +++++++++++++++++-
 paddle/fluid/framework/ir/conv_bn_fuse_pass.h |   6 +
 .../fluid/framework/ir/pass_tester_helper.h   |  29 +-
 .../fluid/operators/compat/batch_norm.pbtxt   |   4 +
 paddle/fluid/operators/compat/conv2d.pbtxt    |   8 +
 paddle/fluid/operators/compat/relu.pbtxt      |   8 +
 6 files changed, 307 insertions(+), 7 deletions(-)

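Note for reviewers (this sits below the "---" cut, so git am ignores it): the
AddOpCompat(...) chains added in this patch register, per op type, the tensors
and attribute ranges a matched op must satisfy, and IsCompat() later evaluates
those requirements against a subgraph or an OpDesc before and after the
rewrite. Below is a minimal, self-contained sketch of that fluent-builder
pattern; all names are hypothetical stand-ins, not the real
paddle::framework::ir::OpCompat API, and only numeric attribute checks are
modeled.

// Sketch of the fluent "op compat" builder idea used by this patch.
// Hypothetical names; numeric-attribute checks only.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

class MiniOpCompat {
 public:
  explicit MiniOpCompat(std::string op_type) : op_type_(std::move(op_type)) {}

  // Start constraining one attribute; presence is required by default.
  MiniOpCompat& AddAttr(const std::string& name) {
    current_ = name;
    checks_[name] = [](double) { return true; };
    return *this;
  }
  MiniOpCompat& IsNumGE(double v) {
    return Chain([v](double x) { return x >= v; });
  }
  MiniOpCompat& IsNumLE(double v) {
    return Chain([v](double x) { return x <= v; });
  }
  // Close the current attribute clause.
  MiniOpCompat& End() {
    current_.clear();
    return *this;
  }

  // IsCompat-style evaluation: every registered check must pass.
  bool Check(const std::map<std::string, double>& attrs) const {
    for (const auto& kv : checks_) {
      auto it = attrs.find(kv.first);
      if (it == attrs.end() || !kv.second(it->second)) return false;
    }
    return true;
  }

 private:
  // AND the new predicate onto whatever was already required.
  MiniOpCompat& Chain(std::function<bool(double)> extra) {
    auto prev = checks_[current_];
    checks_[current_] = [prev, extra](double x) { return prev(x) && extra(x); };
    return *this;
  }

  std::string op_type_;
  std::string current_;
  std::map<std::string, std::function<bool(double)>> checks_;
};

int main() {
  MiniOpCompat bn("batch_norm");
  bn.AddAttr("epsilon").IsNumGE(0.0).IsNumLE(0.001).End();
  std::cout << bn.Check({{"epsilon", 1e-5}}) << "\n";  // 1: accepted
  std::cout << bn.Check({{"epsilon", 0.5}}) << "\n";   // 0: out of range
  return 0;
}

The real builder additionally offers IsTensor(), IsOptional(), IsStringIn(...)
and IsNumEQ(...) clauses, all of which appear in the hunks below.
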
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index 9cc44c941ec..03a78ec3a21 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope,
   }
 }
 
+ConvBNFusePass::ConvBNFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("batch_norm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddInput("Mean")
+      .IsTensor()
+      .End()
+      .AddInput("Variance")
+      .IsTensor()
+      .End()
+      .AddOutput("MeanOut")
+      .IsTensor()
+      .End()
+      .AddOutput("VarianceOut")
+      .IsTensor()
+      .End()
+      .AddOutput("SavedMean")
+      .IsTensor()
+      .End()
+      .AddOutput("SavedVariance")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddAttr("epsilon")
+      .IsNumLE(0.001f)
+      .IsNumGE(0.0f)
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+}
+
 void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_conv_bn_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     VLOG(4) << "handle " + conv_type() + "BN fuse";
-
     // conv, batch_norm,
     // conv_weight, conv_out,
     // bn_scale, bn_bias, bn_mean, bn_variance,
@@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
       }
       conv->Op()->SetOutput("Output",
                             std::vector<std::string>({bn_out->Name()}));
+      if (!IsCompat(*conv->Op())) {
+        LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed.";
+        return;
+      }
       GraphSafeRemoveNodes(
           graph,
           {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
@@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
       desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
       desc.SetType("elementwise_add");
       desc.SetAttr("axis", 1);
+      if (!IsCompat(desc)) {
+        LOG(WARNING)
+            << "conv_bn fuse pass: elementwise_add op compat check failed.";
+        return;
+      }
       auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
 
       GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
@@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_conv_bn_count);
 }
 
+ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .IsOptional()
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("batch_norm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddInput("Mean")
+      .IsTensor()
+      .End()
+      .AddInput("Variance")
+      .IsTensor()
+      .End()
+      .AddOutput("MeanOut")
+      .IsTensor()
+      .End()
+      .AddOutput("VarianceOut")
+      .IsTensor()
+      .End()
+      .AddOutput("SavedMean")
+      .IsTensor()
+      .End()
+      .AddOutput("SavedVariance")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddAttr("epsilon")
+      .IsNumLE(0.001f)
+      .IsNumGE(0.0f)
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(1)
+      .End();
+}
+
 void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_conv_bn_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     VLOG(4) << "handle " + conv_type() + "BN fuse";
-
     // conv, batch_norm,
     // conv_weight, conv_out,
     // bn_scale, bn_bias, bn_mean, bn_variance,
@@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
     // Update the elementwise_add node
     eltwise->Op()->SetAttr("axis", 1);
     eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
-
+    if (!IsCompat(*eltwise->Op())) {
+      LOG(WARNING)
+          << "conv_eltwise_bn fuse pass: eltwise op compat check failed.";
+      return;
+    }
     GraphSafeRemoveNodes(
         graph,
         {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
@@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_conv_bn_count);
 }
 
+ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
+  AddOpCompat(OpCompat("conv2d_transpose"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .IsOptional()
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+}
+
+ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
+  AddOpCompat(OpCompat("conv2d_transpose"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .IsOptional()
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
index 342cd8dad5f..c78dfc2a487 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
@@ -31,6 +31,7 @@ class Graph;
 
 class ConvBNFusePass : public FusePassBase {
  public:
+  ConvBNFusePass();
   virtual ~ConvBNFusePass() {}
   virtual std::string conv_type() const { return "conv2d"; }
 
@@ -41,6 +42,7 @@ class ConvBNFusePass : public FusePassBase {
 
 class ConvEltwiseAddBNFusePass : public FusePassBase {
  public:
+  ConvEltwiseAddBNFusePass();
   virtual ~ConvEltwiseAddBNFusePass() {}
   virtual std::string conv_type() const { return "conv2d"; }
 
@@ -51,11 +53,15 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
 
 class ConvTransposeBNFusePass : public ConvBNFusePass {
  public:
+  ConvTransposeBNFusePass();
+  virtual ~ConvTransposeBNFusePass() {}
   std::string conv_type() const { return "conv2d_transpose"; }
 };
 
 class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
  public:
+  ConvTransposeEltwiseAddBNFusePass();
+  virtual ~ConvTransposeEltwiseAddBNFusePass() {}
   std::string conv_type() const { return "conv2d_transpose"; }
 };
 
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index 4b6068d4776..f5639e7bc9a 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -39,28 +39,49 @@ struct Layers {
   }
 
   VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
-                  bool use_cudnn = false) {
+                  int groups = 1, std::vector<int> strides = {1, 1},
+                  std::vector<int> paddings = {0, 0},
+                  std::string padding_algorithm = "EXPLICIT",
+                  std::vector<int> dilations = {1, 1},
+                  std::string data_format = "NCHW", bool use_cudnn = false) {
     VarDesc* out = lod_tensor(unique_name());
     OpDesc* op = program_.MutableBlock(0)->AppendOp();
     op->SetType("conv2d");
     op->SetInput("Input", {input->Name()});
     op->SetInput("Filter", {filter->Name()});
     op->SetInput("Bias", {bias->Name()});
-    op->SetOutput("Out", {out->Name()});
+    op->SetOutput("Output", {out->Name()});
     op->SetAttr("use_cudnn", use_cudnn);
+    op->SetAttr("groups", groups);
+    op->SetAttr("strides", strides);
+    op->SetAttr("paddings", paddings);
+    op->SetAttr("padding_algorithm", padding_algorithm);
+    op->SetAttr("dilations", dilations);
+    op->SetAttr("data_format", data_format);
     op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                 static_cast<int>(OpRole::kForward));
     return out;
   }
 
-  VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias) {
+  VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias,
+                            int groups = 1, std::vector<int> strides = {1, 1},
+                            std::vector<int> paddings = {0, 0},
+                            std::string padding_algorithm = "EXPLICIT",
+                            std::vector<int> dilations = {1, 1},
+                            std::string data_format = "NCHW") {
     VarDesc* out = lod_tensor(unique_name());
     OpDesc* op = program_.MutableBlock(0)->AppendOp();
     op->SetType("conv2d_transpose");
     op->SetInput("Input", {input->Name()});
     op->SetInput("Filter", {filter->Name()});
     op->SetInput("Bias", {bias->Name()});
-    op->SetOutput("Out", {out->Name()});
+    op->SetOutput("Output", {out->Name()});
+    op->SetAttr("groups", groups);
+    op->SetAttr("strides", strides);
+    op->SetAttr("paddings", paddings);
+    op->SetAttr("padding_algorithm", padding_algorithm);
+    op->SetAttr("dilations", dilations);
+    op->SetAttr("data_format", data_format);
     op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                 static_cast<int>(OpRole::kForward));
     return out;
diff --git a/paddle/fluid/operators/compat/batch_norm.pbtxt b/paddle/fluid/operators/compat/batch_norm.pbtxt
index ac2ccc6296c..ed6162fb91c 100644
--- a/paddle/fluid/operators/compat/batch_norm.pbtxt
+++ b/paddle/fluid/operators/compat/batch_norm.pbtxt
@@ -42,6 +42,10 @@ extra {
   inputs {
     name: "MomentumTensor"
   }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
     type: BOOLEAN
diff --git a/paddle/fluid/operators/compat/conv2d.pbtxt b/paddle/fluid/operators/compat/conv2d.pbtxt
index ae4381bbc43..d8a08b6b410 100644
--- a/paddle/fluid/operators/compat/conv2d.pbtxt
+++ b/paddle/fluid/operators/compat/conv2d.pbtxt
@@ -41,6 +41,14 @@ def {
   }
 }
 extra {
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "skip_quant"
+    type: BOOLEAN
+  }
   attrs {
     name: "is_test"
     type: BOOLEAN
diff --git a/paddle/fluid/operators/compat/relu.pbtxt b/paddle/fluid/operators/compat/relu.pbtxt
index bd0e9988010..271ed91718c 100644
--- a/paddle/fluid/operators/compat/relu.pbtxt
+++ b/paddle/fluid/operators/compat/relu.pbtxt
@@ -12,6 +12,14 @@ extra {
     name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
     type: BOOLEAN
   }
+  attrs {
+    name: "out_threshold"
+    type: FLOAT
+  }
+  attrs {
+    name: "Out0_threshold"
+    type: FLOAT
+  }
   attrs {
     name: "use_mkldnn"
     type: BOOLEAN
-- 
GitLab
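
Reviewer note (after the signature, so mail tools ignore it): a sketch of how
one of these passes is typically exercised, following the existing ir
pass-test convention. The graph setup from a ProgramDesc built with Layers is
elided; PassRegistry::Instance().Get and Pass::Apply are the standard entry
points.

  // Assumes: std::unique_ptr<Graph> graph built from the test ProgramDesc.
  auto pass = PassRegistry::Instance().Get("conv_bn_fuse_pass");
  graph.reset(pass->Apply(graph.release()));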