未验证 提交 e47d8a57 编写于 作者: 王明冬 提交者: GitHub

[pass_enhance]fix the mkldnn model performance drop problem. test=develop (#34625)

上级 4d6f8f2a
......@@ -60,6 +60,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
.IsStringIn({"NHWC", "NCHW"})
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End();
}
......
......@@ -120,6 +120,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
......@@ -129,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
......@@ -267,6 +268,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
......@@ -276,7 +278,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
......
......@@ -620,6 +620,7 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
......@@ -663,6 +664,7 @@ ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
......
......@@ -68,6 +68,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
......
......@@ -47,6 +47,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
......@@ -56,7 +57,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
......@@ -110,6 +111,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
......@@ -135,6 +137,7 @@ Conv3DBiasFusePass::Conv3DBiasFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
......
......@@ -158,11 +158,6 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
if (!pass_->IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(subgraph);
......@@ -175,6 +170,12 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
if (HasFusedActivation(conv_op)) return;
if (!pass_->IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
......
......@@ -77,7 +77,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
.End()
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NCHW", "NHWC"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
......
......@@ -243,6 +243,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册