diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
index 7ac8096bb917772a6d670ef1ffa9f001961f6b1a..dbba001d52101565fb20b4244a11df6a2b91f0e1 100644
--- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
+++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -138,6 +138,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   FusePassBase::Init("data_layout_transfer", graph);
   auto *scope = param_scope();
 
+  // only float16 compute precision need insert transfer_layout.
+  bool is_fp16_precision =
+      static_cast<phi::DataType>(Get<int>("model_precision")) ==
+          phi::DataType::FLOAT16 ||
+      Get<bool>("enable_gpu_half");
+  bool cutlass_enable = false;
+
+#ifdef PADDLE_WITH_CUTLASS
+  cutlass_enable = true;
+#endif
+
+  if (!(is_fp16_precision && cutlass_enable)) return;
+
   PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                     true,
                     platform::errors::InvalidArgument(
@@ -169,14 +182,24 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     if (data_format != "NCHW") return false;
     auto filter_names = op_node->Op()->Input("Filter");
+    auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
+    constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
+    std::unordered_set<std::string> cutlass_act_set = {
+        "relu", "swish", "identity", "leaky_relu"};
+    if (!cutlass_act_set.count(act_type)) {
+      return false;
+    }
 
     // If filter's channel is not multiple of 8, conv2d_fusion not run at
     // nhwc.
     for (const auto &filter_name : filter_names) {
       auto *filter_var = scope->FindLocalVar(filter_name);
       const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
-      if (filter_tensor.dims().size() == 4 &&
-          (filter_tensor.dims()[0] % 8 != 0 ||
-           filter_tensor.dims()[1] % 8 != 0)) {
+      CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
+      int oc = filter_tensor.dims()[0];
+      int ic = filter_tensor.dims()[1];
+      bool cutlass_can_support =
+          oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
+      if (!cutlass_can_support) {
         return false;
       }
     }
@@ -190,6 +213,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
       auto *op_desc = op_node->Op();
       auto nhwc_attr = framework::Attribute(std::string("NHWC"));
       op_desc->SetAttr("data_format", nhwc_attr);
+      op_desc->SetType("conv2d_fusion_cutlass");
       op_desc->Flush();
 
       // transfer weights
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
index 1d309d133795c5d7f7ccceb3e0177b41c37b1246..063eb90d90af17f11ebb826b19d0ec8b2065721b 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -36,7 +36,8 @@ framework::proto::OpDesc PrepareOpDesc(
     const framework::proto::OpDesc& base_desc,
     const std::string& bias,
     const std::string& activation,
-    const std::string& output) {
+    const std::string& output,
+    float alpha) {
   auto proto = base_desc;
   framework::OpDesc desc(proto, nullptr);
   desc.SetType("conv2d_fusion");
@@ -46,6 +47,8 @@ framework::proto::OpDesc PrepareOpDesc(
   desc.SetOutput("Output", {output});
   desc.SetAttr("is_test", true);
   desc.SetAttr("use_cudnn", false);
+  // for leaky_relu use
+  desc.SetAttr("fuse_alpha", alpha);
   desc.Flush();
   return *desc.Proto();
 }
@@ -118,6 +121,25 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
       .AddOutput("Out")
      .IsTensor()
       .End();
+
+  AddOpCompat(OpCompat("swish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("leaky_relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
 }
 
 void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -137,8 +159,28 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
 #endif
 
+  std::unordered_set<std::string> cutlass_act_set;
+  std::unordered_set<std::string> all_act_set = cudnn_act_set;
+
+  bool is_fp16_precision =
+      static_cast<phi::DataType>(Get<int>("model_precision")) ==
+          phi::DataType::FLOAT16 ||
+      Get<bool>("enable_gpu_half");
+  constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
+  if (is_fp16_precision) {
+#ifdef PADDLE_WITH_CUTLASS
+    // cutlass now support these activations
+    // cutlass_act_set.insert("swish");
+    // cutlass_act_set.insert("relu");
+    // cutlass_act_set.insert("identity");
+    // cutlass_act_set.insert("leaky_relu");
+
+    all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
+#endif
+  }
+
   patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
-  pattern(x, cudnn_act_set);
+  pattern(x, all_act_set);
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -152,9 +194,27 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
     std::string bias_name = elementwise_add_in_y->Name();
     std::string act_op_type = act_op->Op()->Type();
    std::string act_op_out = act_out->Name();
+    auto* scope = param_scope();
+    auto* filter_var = scope->FindLocalVar(conv_filter->Name());
+    auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
+    CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
+    // when this conv2d_fusion problem size is not supported by cutlass and not
+    // supported by cuDNN, we should not apply this pass
+    int oc = filter_tensor->dims()[0];
+    int ic = filter_tensor->dims()[1];
+    bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
+                            ic % CUTLASS_NHWC_ALIGNMENT == 0 &&
+                            cutlass_act_set.count(act_op_type);
+    bool cudnn_can_fuse = cudnn_act_set.count(act_op_type);
+    if (!cutlass_can_fuse && !cudnn_can_fuse) {
+      return;
+    }
+
+    float alpha = 0.f;
+    alpha = act_op->Op()->GetAttrIfExists<float>("alpha");
 
     auto new_op_proto =
-        PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out);
+        PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out, alpha);
     framework::OpDesc new_op_desc(new_op_proto, nullptr);
 
     // Create a new node for the fused op.
@@ -195,4 +255,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
             .EQ("relu", 0)
             .EQ("sigmoid", 0)
             .EQ("tanh", 0)
-            .EQ("identity", 0));
+            .EQ("identity", 0)
+            .LE("leaky_relu", 1)
+            .EQ("swish", 0));