未验证 提交 a0a90798 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for conv3d_bias_mkldnn_fuse_pass, test=develop (#33839)

上级 cfe3b40a
...@@ -41,8 +41,10 @@ ConvBiasFusePass::ConvBiasFusePass() { ...@@ -41,8 +41,10 @@ ConvBiasFusePass::ConvBiasFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("strides") .AddAttr("strides")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("paddings") .AddAttr("paddings")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("padding_algorithm") .AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"}) .IsStringIn({"EXPLICIT", "SAME", "VALID"})
...@@ -51,6 +53,7 @@ ConvBiasFusePass::ConvBiasFusePass() { ...@@ -51,6 +53,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"}) .IsStringIn({"NCHW", "NHWC"})
...@@ -86,6 +89,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -86,6 +89,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("output_padding") .AddAttr("output_padding")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("output_size") .AddAttr("output_size")
.IsNumGE(1) .IsNumGE(1)
...@@ -94,10 +98,13 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -94,10 +98,13 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("strides") .AddAttr("strides")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("paddings") .AddAttr("paddings")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("padding_algorithm") .AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"}) .IsStringIn({"EXPLICIT", "SAME", "VALID"})
...@@ -105,19 +112,36 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -105,19 +112,36 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"}) .IsStringIn({"NCHW", "NHWC"})
.End(); .End();
}
AddOpCompat(OpCompat("elementwise_add")) Conv3DBiasFusePass::Conv3DBiasFusePass() {
.AddInput("X") AddOpCompat(OpCompat("conv3d"))
.AddInput("Input")
.IsTensor() .IsTensor()
.End() .End()
.AddInput("Y") .AddInput("Filter")
.IsTensor() .IsTensor()
.End() .End()
.AddOutput("Out") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("axis") .AddAttr("strides")
.IsNumEQ(-1) .IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End(); .End();
} }
......
...@@ -48,6 +48,7 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { ...@@ -48,6 +48,7 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
class Conv3DBiasFusePass : public ConvBiasFusePass { class Conv3DBiasFusePass : public ConvBiasFusePass {
public: public:
Conv3DBiasFusePass();
std::string type() const override { return "conv3d"; } std::string type() const override { return "conv3d"; }
}; };
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册