From a0a907989cc401a9d583b01c4ee9838e42c9f0e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 1 Jul 2021 11:20:13 +0800 Subject: [PATCH] add compat precondition for conv3d_bias_mkldnn_fuse_pass, test=develop (#33839) --- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 36 +++++++++++++++---- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.h | 1 + 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 8d73a35bf09..c03d6a582e4 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -41,8 +41,10 @@ ConvBiasFusePass::ConvBiasFusePass() { .IsTensor() .End() .AddAttr("strides") + .IsType>() .End() .AddAttr("paddings") + .IsType>() .End() .AddAttr("padding_algorithm") .IsStringIn({"EXPLICIT", "SAME", "VALID"}) @@ -51,6 +53,7 @@ ConvBiasFusePass::ConvBiasFusePass() { .IsNumGE(1) .End() .AddAttr("dilations") + .IsType>() .End() .AddAttr("data_format") .IsStringIn({"NCHW", "NHWC"}) @@ -86,6 +89,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsTensor() .End() .AddAttr("output_padding") + .IsType>() .End() .AddAttr("output_size") .IsNumGE(1) @@ -94,10 +98,13 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsNumGE(1) .End() .AddAttr("dilations") + .IsType>() .End() .AddAttr("strides") + .IsType>() .End() .AddAttr("paddings") + .IsType>() .End() .AddAttr("padding_algorithm") .IsStringIn({"EXPLICIT", "SAME", "VALID"}) @@ -105,19 +112,36 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .AddAttr("data_format") .IsStringIn({"NCHW", "NHWC"}) .End(); +} - AddOpCompat(OpCompat("elementwise_add")) - .AddInput("X") +Conv3DBiasFusePass::Conv3DBiasFusePass() { + AddOpCompat(OpCompat("conv3d")) + .AddInput("Input") .IsTensor() .End() - .AddInput("Y") + .AddInput("Filter") .IsTensor() .End() - .AddOutput("Out") + .AddOutput("Output") .IsTensor() .End() - .AddAttr("axis") - .IsNumEQ(-1) + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC"}) .End(); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index 20c683c094e..a74d7443ee1 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -48,6 +48,7 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { class Conv3DBiasFusePass : public ConvBiasFusePass { public: + Conv3DBiasFusePass(); std::string type() const override { return "conv3d"; } }; } // namespace ir -- GitLab