From 08dc5bc27e3fc1e3822e73c36d8a66b53daa5118 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Mon, 4 Jan 2021 16:11:16 +0800 Subject: [PATCH] fix op version checker of pass bug (#30028) * fix op version checker of pass bug * fix code style * update pass version --- .../framework/ir/quant_conv2d_dequant_fuse_pass.cc | 5 ++--- paddle/fluid/framework/op_version_registry.h | 13 ++++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 96c5546d21..c2ee2fc6b3 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -327,9 +327,8 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(quant_conv2d_dequant_fuse_pass, paddle::framework::ir::QuantDequantFusePass); -REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass); -REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) +REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("conv2d", 1) @@ -338,5 +337,5 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .EQ("fake_quantize_abs_max", 0) .EQ("fake_quantize_range_abs_max", 0) .EQ("fake_quantize_moving_average_abs_max", 0) - .EQ("fake_channel_wise_quantize_abs_max", 0) + .LE("fake_channel_wise_quantize_abs_max", 1) .EQ("fake_dequantize_max_abs", 0)); diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 125346cb22..d8321939f6 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -240,7 +240,13 @@ class OpVersionComparator { if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \ version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \ } \ - return version_id cmp_math target_version_; \ + bool check_ok = version_id cmp_math target_version_; \ + if (!check_ok) { \ + LOG(WARNING) << "Check op version in pass failed. op name:" \ + << op_name_.c_str() << " op_version:" << version_id \ + << " target_version:" << target_version_; \ + } \ + return check_ok; \ } \ virtual ~OpVersion##cmp_name##Comparator() {} \ \ @@ -326,6 +332,11 @@ class PassVersionCheckerRegistrar { return instance; } PassVersionCheckers& Register(const std::string& pass_name) { + PADDLE_ENFORCE_EQ(pass_version_checkers_map_.find(pass_name), + pass_version_checkers_map_.end(), + platform::errors::AlreadyExists( + "PassVersionCheckers(%s) has alredy been registered.", + pass_name.c_str())); return pass_version_checkers_map_[pass_name]; } bool IsPassCompatible(const std::string& fuse_pass_name) const { -- GitLab