From c95c35a29bcd5340491d9b41c4fbf9c118fc94df Mon Sep 17 00:00:00 2001 From: gem5 <117625383+linsheng011@users.noreply.github.com> Date: Thu, 12 Jan 2023 15:58:27 +0800 Subject: [PATCH] conv2d_fusion(cudnn or cutlass) (#49707) --- .../ir/conv2d_fusion_layout_transfer_pass.cc | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc index dd4e0735600..8e7d435cb5a 100644 --- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc +++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc @@ -155,7 +155,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { } #endif - if (!(is_fp16_precision && cutlass_enable)) return; + if (!is_fp16_precision) return; PADDLE_ENFORCE_EQ(graph->IsMainGraph(), true, @@ -180,16 +180,31 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { std::string target_op_type = "conv2d_fusion"; std::unordered_set valid_ops; - auto OpIsValid = [&](ir::Node *op_node) -> bool { + auto cuDNNIsValid = [&](ir::Node *op_node) -> bool { if (op_node->Op()->Type() != target_op_type) return false; - auto data_format = op_node->Op()->GetAttrIfExists("data_format"); if (data_format != "NCHW") return false; - auto filter_names = op_node->Op()->Input("Filter"); - auto act_type = op_node->Op()->GetAttrIfExists("activation"); constexpr int CUTLASS_NHWC_ALIGNMENT = 8; + // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc. + for (const auto &filter_name : filter_names) { + auto *filter_var = scope->FindLocalVar(filter_name); + const auto &filter_tensor = filter_var->Get(); + CHECK_EQ(filter_tensor.dims().size() == 4UL, true); + int oc = filter_tensor.dims()[0]; + int ic = filter_tensor.dims()[1]; + bool cutlass_can_support = + oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0; + if (!cutlass_can_support) { + return false; + } + } + return true; + }; + + auto CutlassIsValid = [&](ir::Node *op_node) -> bool { + auto act_type = op_node->Op()->GetAttrIfExists("activation"); // conv2d_fusion has two forms: conv + bias + act, conv + bias + // elmentwise_add + act. std::unordered_set cutlass_cba_act_set = { @@ -206,31 +221,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { return false; } } - - // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc. - for (const auto &filter_name : filter_names) { - auto *filter_var = scope->FindLocalVar(filter_name); - const auto &filter_tensor = filter_var->Get(); - CHECK_EQ(filter_tensor.dims().size() == 4UL, true); - int oc = filter_tensor.dims()[0]; - int ic = filter_tensor.dims()[1]; - bool cutlass_can_support = - oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0; - if (!cutlass_can_support) { - return false; - } - } return true; }; for (auto *op_node : op_nodes) { CHECK_EQ(op_node->IsOp(), true); - if (OpIsValid(op_node)) { + if (cuDNNIsValid(op_node)) { valid_ops.insert(op_node); auto *op_desc = op_node->Op(); auto nhwc_attr = framework::Attribute(std::string("NHWC")); op_desc->SetAttr("data_format", nhwc_attr); - op_desc->SetType("conv2d_fusion_cutlass"); + if (cutlass_enable && CutlassIsValid(op_node)) { + op_desc->SetType("conv2d_fusion_cutlass"); + } op_desc->Flush(); // transfer weights -- GitLab