diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 74bbe24eb82f5d3acd16ef6d51e71cdc77341544..a7514038d400b66cb1253242841186cfbb31610c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -70,7 +70,7 @@ ConvBiasFusePass::ConvBiasFusePass() { .IsTensor() .End() .AddAttr("axis") - .IsIntIn({-1, 0}) + .IsIntIn({1, 3}) .End(); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc index 80a9ef7eda724a49046f636f0617cbccf51c68a2..e41c35ba33fdc91fbd935efc611bb2b57fdb749c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc @@ -52,7 +52,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("Bias", {}); } else if (type == "elementwise_add") { op->SetAttr("use_mkldnn", true); - op->SetAttr("axis", -1); + op->SetAttr("axis", 1); op->SetInput("X", {inputs[0]}); op->SetInput("Y", {inputs[1]}); op->SetOutput("Out", outputs);