From 0f6c5459da5c7db25a667c49fd17a5026ba97f1d Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Fri, 9 Dec 2022 13:50:23 +0800
Subject: [PATCH] [Paddle Inference]add cutlass act set in
 conv_elementwise_add_act_fuse_pass (#48838)

* add cutlass act set in conv_elementwise_add_act_fuse_pass
---
 .../ir/conv2d_fusion_layout_transfer_pass.cc  | 30 +++++++-
 .../ir/conv_elementwise_add_act_fuse_pass.cc  | 70 +++++++++++++++++--
 2 files changed, 93 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
index 7ac8096bb91..dbba001d521 100644
--- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
+++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -138,6 +138,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   FusePassBase::Init("data_layout_transfer", graph);
   auto *scope = param_scope();
 
+  // only float16 compute precision need insert transfer_layout.
+  bool is_fp16_precision =
+      static_cast<phi::DataType>(Get<int>("model_precision")) ==
+          phi::DataType::FLOAT16 ||
+      Get<bool>("enable_gpu_half");
+  bool cutlass_enable = false;
+
+#ifdef PADDLE_WITH_CUTLASS
+  cutlass_enable = true;
+#endif
+
+  if (!(is_fp16_precision && cutlass_enable)) return;
+
   PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                     true,
                     platform::errors::InvalidArgument(
@@ -169,14 +182,24 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     if (data_format != "NCHW") return false;
 
     auto filter_names = op_node->Op()->Input("Filter");
+    auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
+    constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
+    std::unordered_set<std::string> cutlass_act_set = {
+        "relu", "swish", "identity", "leaky_relu"};
+    if (!cutlass_act_set.count(act_type)) {
+      return false;
+    }
 
     // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
     for (const auto &filter_name : filter_names) {
       auto *filter_var = scope->FindLocalVar(filter_name);
       const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
-      if (filter_tensor.dims().size() == 4 &&
-          (filter_tensor.dims()[0] % 8 != 0 ||
-           filter_tensor.dims()[1] % 8 != 0)) {
+      CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
+      int oc = filter_tensor.dims()[0];
+      int ic = filter_tensor.dims()[1];
+      bool cutlass_can_support =
+          oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
+      if (!cutlass_can_support) {
         return false;
       }
     }
@@ -190,6 +213,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
       auto *op_desc = op_node->Op();
       auto nhwc_attr = framework::Attribute(std::string("NHWC"));
       op_desc->SetAttr("data_format", nhwc_attr);
+      op_desc->SetType("conv2d_fusion_cutlass");
       op_desc->Flush();
 
       // transfer weights
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
index 1d309d13379..063eb90d90a 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -36,7 +36,8 @@ framework::proto::OpDesc PrepareOpDesc(
     const framework::proto::OpDesc& base_desc,
     const std::string& bias,
     const std::string& activation,
-    const std::string& output) {
+    const std::string& output,
+    float alpha) {
   auto proto = base_desc;
   framework::OpDesc desc(proto, nullptr);
   desc.SetType("conv2d_fusion");
@@ -46,6 +47,8 @@ framework::proto::OpDesc PrepareOpDesc(
   desc.SetOutput("Output", {output});
   desc.SetAttr("is_test", true);
   desc.SetAttr("use_cudnn", false);
+  // for leaky_relu use
+  desc.SetAttr("fuse_alpha", alpha);
   desc.Flush();
   return *desc.Proto();
 }
@@ -118,6 +121,25 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
       .AddOutput("Out")
       .IsTensor()
       .End();
+
+  AddOpCompat(OpCompat("swish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("leaky_relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
 }
 
 void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -137,8 +159,28 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
 #endif
 
+  std::unordered_set<std::string> cutlass_act_set;
+  std::unordered_set<std::string> all_act_set = cudnn_act_set;
+
+  bool is_fp16_precision =
+      static_cast<phi::DataType>(Get<int>("model_precision")) ==
+          phi::DataType::FLOAT16 ||
+      Get<bool>("enable_gpu_half");
+  constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
+  if (is_fp16_precision) {
+#ifdef PADDLE_WITH_CUTLASS
+    // cutlass now support these activations
+    // cutlass_act_set.insert("swish");
+    // cutlass_act_set.insert("relu");
+    // cutlass_act_set.insert("identity");
+    // cutlass_act_set.insert("leaky_relu");
+
+    all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
+#endif
+  }
+
   patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
-  pattern(x, cudnn_act_set);
+  pattern(x, all_act_set);
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -152,9 +194,27 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
     std::string bias_name = elementwise_add_in_y->Name();
     std::string act_op_type = act_op->Op()->Type();
     std::string act_op_out = act_out->Name();
+    auto* scope = param_scope();
+    auto* filter_var = scope->FindLocalVar(conv_filter->Name());
+    auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
+    CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
+    // when this conv2d_fusion problem size is not supported by cutlass and not
+    // supported by cuDNN, we should not apply this pass
+    int oc = filter_tensor->dims()[0];
+    int ic = filter_tensor->dims()[1];
+    bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
+                            ic % CUTLASS_NHWC_ALIGNMENT == 0 &&
+                            cutlass_act_set.count(act_op_type);
+    bool cudnn_can_fuse = cudnn_act_set.count(act_op_type);
+    if (!cutlass_can_fuse && !cudnn_can_fuse) {
+      return;
+    }
+
+    float alpha = 0.f;
+    alpha = act_op->Op()->GetAttrIfExists<float>("alpha");
 
     auto new_op_proto =
-        PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out);
+        PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out, alpha);
     framework::OpDesc new_op_desc(new_op_proto, nullptr);
 
     // Create a new node for the fused op.
@@ -195,4 +255,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
             .EQ("relu", 0)
             .EQ("sigmoid", 0)
             .EQ("tanh", 0)
-            .EQ("identity", 0));
+            .EQ("identity", 0)
+            .LE("leaky_relu", 1)
+            .EQ("swish", 0));
-- 
GitLab