Unverified commit e47d8a57, authored by 王明冬, committed by GitHub

[pass_enhance]fix the mkldnn model performance drop problem. test=develop (#34625)

Parent 4d6f8f2a
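In brief: the mkldnn slowdown apparently came from the op-compat checks in these fuse passes being stricter than real models, so IsCompat() rejected otherwise valid graphs and the fusions were silently skipped. The diff relaxes the checks (padding_algorithm becomes optional, data_format additionally accepts "AnyLayout") and, in ResidualConnectionMKLDNNFusePass, runs the IsCompat() check only after the early-return guards. Below is a minimal sketch of the OpCompat builder pattern involved, using only calls visible in this diff; ExampleFusePass is a hypothetical name, not code from this commit:

// Minimal sketch, not this commit's code: ExampleFusePass is hypothetical;
// the builder calls mirror those used by the passes in the diff below.
ExampleFusePass::ExampleFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddAttr("padding_algorithm")
      .IsOptional()  // a model that omits this attribute still passes IsCompat()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})  // "AnyLayout" now accepted
      .End();
}
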
@@ -60,6 +60,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
       .IsStringIn({"NHWC", "NCHW"})
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End();
 }
...
@@ -120,6 +120,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -129,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("affine_channel"))
@@ -267,6 +268,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -276,7 +278,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("affine_channel"))
       .AddInput("X")
...
@@ -620,6 +620,7 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
@@ -663,6 +664,7 @@ ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
...
@@ -68,6 +68,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
       .AddAttr("paddings")
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
...
@@ -47,6 +47,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -56,7 +57,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
@@ -110,6 +111,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
@@ -135,6 +137,7 @@ Conv3DBiasFusePass::Conv3DBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
...
@@ -158,11 +158,6 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
   Node* elementwise_add_op;
   Node* elementwise_add_identity;
   Node* elementwise_add_out;
 
-  if (!pass_->IsCompat(subgraph, graph)) {
-    LOG(WARNING)
-        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
-    return;
-  }
   std::tie(conv_op, conv_input, conv_filter, conv_output) =
       get_node_from_conv_op(subgraph);
@@ -175,6 +170,12 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
   if (HasFusedActivation(conv_op)) return;
 
+  if (!pass_->IsCompat(subgraph, graph)) {
+    LOG(WARNING)
+        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
+    return;
+  }
+
   conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
   conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
   conv_op->Op()->SetAttr("fuse_residual_connection", true);
...
@@ -77,7 +77,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
       .End()
       .AddAttr("data_format")
       .IsOptional()
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 }
...
@@ -243,6 +243,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
...