diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 8d73a35bf09be83f6a99343a3941a01d66fbd5c1..c03d6a582e4312e5ab2709850a7d7bec6f228977 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -41,8 +41,10 @@ ConvBiasFusePass::ConvBiasFusePass() { .IsTensor() .End() .AddAttr("strides") + .IsType>() .End() .AddAttr("paddings") + .IsType>() .End() .AddAttr("padding_algorithm") .IsStringIn({"EXPLICIT", "SAME", "VALID"}) @@ -51,6 +53,7 @@ ConvBiasFusePass::ConvBiasFusePass() { .IsNumGE(1) .End() .AddAttr("dilations") + .IsType>() .End() .AddAttr("data_format") .IsStringIn({"NCHW", "NHWC"}) @@ -86,6 +89,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsTensor() .End() .AddAttr("output_padding") + .IsType>() .End() .AddAttr("output_size") .IsNumGE(1) @@ -94,10 +98,13 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsNumGE(1) .End() .AddAttr("dilations") + .IsType>() .End() .AddAttr("strides") + .IsType>() .End() .AddAttr("paddings") + .IsType>() .End() .AddAttr("padding_algorithm") .IsStringIn({"EXPLICIT", "SAME", "VALID"}) @@ -105,19 +112,36 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .AddAttr("data_format") .IsStringIn({"NCHW", "NHWC"}) .End(); +} - AddOpCompat(OpCompat("elementwise_add")) - .AddInput("X") +Conv3DBiasFusePass::Conv3DBiasFusePass() { + AddOpCompat(OpCompat("conv3d")) + .AddInput("Input") .IsTensor() .End() - .AddInput("Y") + .AddInput("Filter") .IsTensor() .End() - .AddOutput("Out") + .AddOutput("Output") .IsTensor() .End() - .AddAttr("axis") - .IsNumEQ(-1) + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC"}) .End(); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index 20c683c094edfea74f3220875c805a5f46a35c87..a74d7443ee1fe13212c6514d415a16d6f0cb2f5b 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -48,6 +48,7 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { class Conv3DBiasFusePass : public ConvBiasFusePass { public: + Conv3DBiasFusePass(); std::string type() const override { return "conv3d"; } }; } // namespace ir