diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 96c5546d21208b2708774a78bb7fe693b9a440a5..c2ee2fc6b32e797c63e0ced08caf346fa6ac221d 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -327,9 +327,8 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(quant_conv2d_dequant_fuse_pass, paddle::framework::ir::QuantDequantFusePass); -REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass); -REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) +REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("conv2d", 1) @@ -338,5 +337,5 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .EQ("fake_quantize_abs_max", 0) .EQ("fake_quantize_range_abs_max", 0) .EQ("fake_quantize_moving_average_abs_max", 0) - .EQ("fake_channel_wise_quantize_abs_max", 0) + .LE("fake_channel_wise_quantize_abs_max", 1) .EQ("fake_dequantize_max_abs", 0)); diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 125346cb22789fd194ce2846f026702a723b4999..d8321939f6c61d55c68055dc0d03fa0153489379 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -240,7 +240,13 @@ class OpVersionComparator { if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \ version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \ } \ - return version_id cmp_math target_version_; \ + bool check_ok = version_id cmp_math target_version_; \ + if (!check_ok) { \ + LOG(WARNING) << "Check op version in pass failed. op name:" \ + << op_name_.c_str() << " op_version:" << version_id \ + << " target_version:" << target_version_; \ + } \ + return check_ok; \ } \ virtual ~OpVersion##cmp_name##Comparator() {} \ \ @@ -326,6 +332,11 @@ class PassVersionCheckerRegistrar { return instance; } PassVersionCheckers& Register(const std::string& pass_name) { + PADDLE_ENFORCE_EQ(pass_version_checkers_map_.find(pass_name), + pass_version_checkers_map_.end(), + platform::errors::AlreadyExists( + "PassVersionCheckers(%s) has alredy been registered.", + pass_name.c_str())); return pass_version_checkers_map_[pass_name]; } bool IsPassCompatible(const std::string& fuse_pass_name) const {