Unverified commit c95c35a2 authored by gem5, committed by GitHub

conv2d_fusion(cudnn or cutlass) (#49707)

Parent 3fb4a08c
...
@@ -155,7 +155,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   }
 #endif
-  if (!(is_fp16_precision && cutlass_enable)) return;
+  if (!is_fp16_precision) return;
   PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                     true,
...
@@ -180,16 +180,31 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   std::string target_op_type = "conv2d_fusion";
   std::unordered_set<ir::Node *> valid_ops;
-  auto OpIsValid = [&](ir::Node *op_node) -> bool {
+  auto cuDNNIsValid = [&](ir::Node *op_node) -> bool {
     if (op_node->Op()->Type() != target_op_type) return false;
     auto data_format =
         op_node->Op()->GetAttrIfExists<std::string>("data_format");
     if (data_format != "NCHW") return false;
     auto filter_names = op_node->Op()->Input("Filter");
-    auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
     constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
+    // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
+    for (const auto &filter_name : filter_names) {
+      auto *filter_var = scope->FindLocalVar(filter_name);
+      const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
+      CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
+      int oc = filter_tensor.dims()[0];
+      int ic = filter_tensor.dims()[1];
+      bool cutlass_can_support =
+          oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
+      if (!cutlass_can_support) {
+        return false;
+      }
+    }
+    return true;
+  };
+  auto CutlassIsValid = [&](ir::Node *op_node) -> bool {
+    auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
     // conv2d_fusion has two forms: conv + bias + act, conv + bias +
     // elmentwise_add + act.
     std::unordered_set<std::string> cutlass_cba_act_set = {
...
@@ -206,31 +221,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
         return false;
       }
     }
-    // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
-    for (const auto &filter_name : filter_names) {
-      auto *filter_var = scope->FindLocalVar(filter_name);
-      const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
-      CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
-      int oc = filter_tensor.dims()[0];
-      int ic = filter_tensor.dims()[1];
-      bool cutlass_can_support =
-          oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
-      if (!cutlass_can_support) {
-        return false;
-      }
-    }
     return true;
   };
   for (auto *op_node : op_nodes) {
     CHECK_EQ(op_node->IsOp(), true);
-    if (OpIsValid(op_node)) {
+    if (cuDNNIsValid(op_node)) {
       valid_ops.insert(op_node);
       auto *op_desc = op_node->Op();
       auto nhwc_attr = framework::Attribute(std::string("NHWC"));
       op_desc->SetAttr("data_format", nhwc_attr);
-      op_desc->SetType("conv2d_fusion_cutlass");
+      if (cutlass_enable && CutlassIsValid(op_node)) {
+        op_desc->SetType("conv2d_fusion_cutlass");
+      }
       op_desc->Flush();
       // transfer weights
...
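
A minimal standalone sketch of the dispatch this change introduces, not part of the commit: an op that passes the cuDNN/NHWC check is always rewritten to NHWC, and it is retyped to conv2d_fusion_cutlass only when cutlass is enabled and its activation is one the cutlass kernel can fuse. ConvFusionDesc, SelectOpType, and the contents of kCbaActSet below are hypothetical stand-ins (the real activation set is elided in the diff), not Paddle APIs.

#include <string>
#include <unordered_set>

// Hypothetical stand-in for the attributes the pass reads from a conv2d_fusion op.
struct ConvFusionDesc {
  std::string data_format;  // "NCHW" before the pass runs
  std::string activation;   // fused activation, e.g. "relu"
  int oc = 0;               // filter output channels, dims()[0]
  int ic = 0;               // filter input channels, dims()[1]
};

constexpr int kCutlassNhwcAlignment = 8;

// Mirrors cuDNNIsValid: NCHW layout and filter channels aligned to 8,
// otherwise the op is left untouched by the pass.
bool CudnnIsValid(const ConvFusionDesc& op) {
  if (op.data_format != "NCHW") return false;
  return op.oc % kCutlassNhwcAlignment == 0 &&
         op.ic % kCutlassNhwcAlignment == 0;
}

// Mirrors CutlassIsValid: the activation must be in the cutlass whitelist
// (placeholder entries here; the real set is not shown in this diff).
bool CutlassIsValid(const ConvFusionDesc& op) {
  static const std::unordered_set<std::string> kCbaActSet = {"relu",
                                                             "identity"};
  return kCbaActSet.count(op.activation) > 0;
}

// Returns the op type the pass would leave behind for this op.
std::string SelectOpType(const ConvFusionDesc& op, bool cutlass_enable) {
  if (!CudnnIsValid(op)) return "conv2d_fusion";  // not rewritten at all
  return (cutlass_enable && CutlassIsValid(op))
             ? "conv2d_fusion_cutlass"   // cutlass NHWC kernel
             : "conv2d_fusion";          // cuDNN NHWC path
}

Before this commit the pass required cutlass to be enabled and retyped every valid op to conv2d_fusion_cutlass; now the NHWC layout transfer also benefits the plain cuDNN path.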