diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 76c6ca24aaaf062dad71ff3d39aef77d74582ba3..716c49dcb12d9b432dfddd54d1dc3fa33570f26f 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -158,7 +158,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d_transpose", 0) + .LE("conv2d_transpose", 1) .EQ("elementwise_add", 0)); REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 96f88e70a98d453cbf7d74e6f17194855003555a..895c396e1e614fb06c37d519b45c942429bbf9a2 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -326,6 +326,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(quant_conv2d_dequant_fuse_pass, paddle::framework::ir::QuantDequantFusePass); +REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass); REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index c9d3084724bcdbcee6e0e43a985d5c41c5a8ae84..c121e6429dbb414ceb3773ad62f23951f525626d 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -308,7 +308,7 @@ class PassVersionCheckerRegistrar { bool IsPassCompatible(const std::string& fuse_pass_name) const { auto iter = pass_version_checkers_map_.find(fuse_pass_name); if (iter == pass_version_checkers_map_.end()) { - return true; + return false; } return iter->second.IsPassCompatible(); } diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc index ef8384c1e7ee1d58f1e8e8cfda6d0ae54fc756ed..888dd6de0618ba1d996983eb4229f53899dc2e86 100644 --- a/paddle/fluid/framework/op_version_registry_test.cc +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -57,6 +57,10 @@ TEST(test_operator_version, test_operator_version) { TEST(test_pass_op_version_checker, test_pass_op_version_checker) { const std::string fake_op_name{"op_name__"}; + ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( + "no_registered_capability_pass")); + + REGISTER_PASS_CAPABILITY(no_bind_pass); ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( "no_bind_pass")); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_bias_fuse_pass.py similarity index 100% rename from python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py rename to python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_bias_fuse_pass.py