未验证 提交 0fc181db 编写于 作者: L lidanqing 提交者: GitHub

[Fix bug] If the pass name is not found, IsCompatible should return false (#28475)

上级 b258caf4
......@@ -158,7 +158,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d_transpose", 0)
.LE("conv2d_transpose", 1)
.EQ("elementwise_add", 0));
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
......
......@@ -326,6 +326,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
paddle::framework::ir::QuantDequantFusePass);
REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.AddCombination(
......
......@@ -308,7 +308,7 @@ class PassVersionCheckerRegistrar {
bool IsPassCompatible(const std::string& fuse_pass_name) const {
auto iter = pass_version_checkers_map_.find(fuse_pass_name);
if (iter == pass_version_checkers_map_.end()) {
return true;
return false;
}
return iter->second.IsPassCompatible();
}
......
......@@ -57,6 +57,10 @@ TEST(test_operator_version, test_operator_version) {
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
const std::string fake_op_name{"op_name__"};
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_registered_capability_pass"));
REGISTER_PASS_CAPABILITY(no_bind_pass);
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册