未验证 提交 0fc181db 编写于 作者: L lidanqing 提交者: GitHub

[Fix bug] If the pass name is not found, IsCompatible should return false (#28475)

上级 b258caf4
...@@ -158,7 +158,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, ...@@ -158,7 +158,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass) REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d_transpose", 0) .LE("conv2d_transpose", 1)
.EQ("elementwise_add", 0)); .EQ("elementwise_add", 0));
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
......
...@@ -326,6 +326,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -326,6 +326,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(quant_conv2d_dequant_fuse_pass, REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
paddle::framework::ir::QuantDequantFusePass); paddle::framework::ir::QuantDequantFusePass);
REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.AddCombination( .AddCombination(
......
...@@ -308,7 +308,7 @@ class PassVersionCheckerRegistrar { ...@@ -308,7 +308,7 @@ class PassVersionCheckerRegistrar {
bool IsPassCompatible(const std::string& fuse_pass_name) const { bool IsPassCompatible(const std::string& fuse_pass_name) const {
auto iter = pass_version_checkers_map_.find(fuse_pass_name); auto iter = pass_version_checkers_map_.find(fuse_pass_name);
if (iter == pass_version_checkers_map_.end()) { if (iter == pass_version_checkers_map_.end()) {
return true; return false;
} }
return iter->second.IsPassCompatible(); return iter->second.IsPassCompatible();
} }
......
...@@ -57,6 +57,10 @@ TEST(test_operator_version, test_operator_version) { ...@@ -57,6 +57,10 @@ TEST(test_operator_version, test_operator_version) {
TEST(test_pass_op_version_checker, test_pass_op_version_checker) { TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
const std::string fake_op_name{"op_name__"}; const std::string fake_op_name{"op_name__"};
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_registered_capability_pass"));
REGISTER_PASS_CAPABILITY(no_bind_pass);
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass")); "no_bind_pass"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册