Unverified · Commit b47923b4 authored by Sławomir Siwek, committed by GitHub

Add bf16 support for fused matmul (#50254)

* add support for bf16 fused_ops

* fused_matmul only
Parent 209d534d
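The first hunk below registers "fused_matmul" in the op list used by the Bfloat16Placement pattern, which is what lets the CPU bfloat16 placement pass tag fused matmul nodes for bf16 execution; the remaining hunks drop the optional fused-matmul attribute declarations (matmul_alpha, fuse_activation, fused_reshape_X, ...) from the OpCompat checks of four oneDNN fuse passes. As a rough illustration of the allow-list mechanism only, here is a standalone C++ sketch; FakeOp, MarkBf16Ops, and the truncated op list are hypothetical stand-ins, not the actual pass implementation.

// Standalone sketch (not PaddlePaddle source): ops whose type appears in the
// allow-list get mkldnn_data_type switched to "bfloat16"; others keep fp32.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct FakeOp {                               // hypothetical stand-in for a graph node
  std::string type;
  std::string mkldnn_data_type = "float32";
};

void MarkBf16Ops(std::vector<FakeOp>* ops) {  // hypothetical helper
  // Excerpt of the allow-list extended by this commit ("fused_matmul" is the new entry).
  static const std::unordered_set<std::string> kBf16Ops = {
      "layer_norm", "matmul", "matmul_v2", "fused_matmul",
      "pool2d",     "prelu",  "relu"};
  for (auto& op : *ops) {
    if (kBf16Ops.count(op.type) != 0) op.mkldnn_data_type = "bfloat16";
  }
}

int main() {
  std::vector<FakeOp> ops = {{"fused_matmul"}, {"my_custom_op"}};
  MarkBf16Ops(&ops);
  for (const auto& op : ops) {
    std::cout << op.type << " -> " << op.mkldnn_data_type << "\n";
  }
  // Prints:
  //   fused_matmul -> bfloat16
  //   my_custom_op -> float32
  return 0;
}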
@@ -2819,6 +2819,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
     "layer_norm",
     "matmul",
     "matmul_v2",
+    "fused_matmul",
     "pool2d",
     "prelu",
     "relu",
...
@@ -146,50 +146,6 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
   AddOpCompat(OpCompat("abs"))
...
@@ -150,50 +150,6 @@ MatmulElementwiseAddMKLDNNFusePass::MatmulElementwiseAddMKLDNNFusePass() {
       .End()
      .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
   AddOpCompat(OpCompat("elementwise_add"))
...
@@ -174,50 +174,6 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
   AddOpCompat(OpCompat("transpose2"))
...
@@ -265,50 +265,6 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
       .End()
       .AddAttr("trans_y")
       .IsType<bool>()
-      .End()
-      .AddAttr("matmul_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_activation")
-      .IsType<std::string>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_alpha")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fuse_beta")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_output_scale")
-      .IsType<float>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_X")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Y")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_reshape_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("fused_transpose_Out")
-      .IsType<std::vector<int>>()
-      .IsOptional()
       .End();
 }
...