From b47923b46103a8b370481885ab365f6f5acd7538 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C5=82awomir=20Siwek?=
Date: Wed, 8 Feb 2023 14:01:08 +0100
Subject: [PATCH] Add bf16 support for fused matmul (#50254)

* add support for bf16 fused_ops

* fused_matmul only
---
 .../framework/ir/graph_pattern_detector.cc    |  1 +
 .../matmul_activation_mkldnn_fuse_pass.cc     | 44 -------------------
 ...matmul_elementwise_add_mkldnn_fuse_pass.cc | 44 -------------------
 ...tmul_transpose_reshape_mkldnn_fuse_pass.cc | 44 -------------------
 ...shape_transpose_matmul_mkldnn_fuse_pass.cc | 44 -------------------
 5 files changed, 1 insertion(+), 176 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index d9c351e288..6591ede1f6 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2819,6 +2819,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
                                        "layer_norm",
                                        "matmul",
                                        "matmul_v2",
+                                       "fused_matmul",
                                        "pool2d",
                                        "prelu",
                                        "relu",
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
index 61bd888715..76529d3d1a 100644
--- a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
@@ -146,50 +146,6 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
 
   AddOpCompat(OpCompat("abs"))
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc
index 680600a403..4e6c64ca78 100644
--- a/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc
@@ -150,50 +150,6 @@ MatmulElementwiseAddMKLDNNFusePass::MatmulElementwiseAddMKLDNNFusePass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc
index 779c39834c..5bba8606f4 100644
--- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.cc
@@ -174,50 +174,6 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
 
   AddOpCompat(OpCompat("transpose2"))
diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
index 508cad94e8..487099e94e 100644
--- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
@@ -265,50 +265,6 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
 }
 
--
GitLab
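Note (not part of the patch): with "fused_matmul" added to the Bfloat16Placement op list, the bf16 placement pass can mark the fused matmul operators produced by the matmul fuse passes to run in bfloat16. The sketch below shows one way this change would typically be exercised through the Paddle inference C++ API. It is a minimal, non-authoritative example: the model file paths, input name handling, and shapes are hypothetical placeholders, and it assumes a CPU build of paddle_infer where EnableMKLDNN() and EnableMkldnnBfloat16() are available.

// Minimal sketch (assumptions noted above): enable oneDNN and bf16 so the IR
// fuse passes can create fused_matmul ops and the bf16 placement pass,
// extended by this patch, can mark them to run in bfloat16.
#include <vector>

#include "paddle_inference_api.h"  // paddle_infer API; header location may vary by install

int main() {
  // Hypothetical model files.
  paddle_infer::Config config("./model/inference.pdmodel",
                              "./model/inference.pdiparams");
  config.SwitchIrOptim(true);     // run IR passes, incl. the matmul fuse passes
  config.EnableMKLDNN();          // select oneDNN kernels on CPU
  config.EnableMkldnnBfloat16();  // enables bf16 placement for supported ops

  auto predictor = paddle_infer::CreatePredictor(config);

  // Hypothetical single input of shape [1, 128].
  auto input_names = predictor->GetInputNames();
  auto x = predictor->GetInputHandle(input_names[0]);
  std::vector<float> data(1 * 128, 1.0f);
  x->Reshape({1, 128});
  x->CopyFromCpu(data.data());

  predictor->Run();
  return 0;
}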