未验证 提交 a0a90798 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for conv3d_bias_mkldnn_fuse_pass, test=develop (#33839)

上级 cfe3b40a
......@@ -41,8 +41,10 @@ ConvBiasFusePass::ConvBiasFusePass() {
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
......@@ -51,6 +53,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
......@@ -86,6 +89,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.End()
.AddAttr("output_size")
.IsNumGE(1)
......@@ -94,10 +98,13 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
......@@ -105,19 +112,36 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End();
}
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
Conv3DBiasFusePass::Conv3DBiasFusePass() {
AddOpCompat(OpCompat("conv3d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Y")
.AddInput("Filter")
.IsTensor()
.End()
.AddOutput("Out")
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(-1)
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End();
}
......
......@@ -48,6 +48,7 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
class Conv3DBiasFusePass : public ConvBiasFusePass {
public:
Conv3DBiasFusePass();
std::string type() const override { return "conv3d"; }
};
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册