未验证 提交 477b0c46 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix op version checker of pass bug (#30028) (#30084)

上级 84c2315a
...@@ -327,9 +327,8 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -327,9 +327,8 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(quant_conv2d_dequant_fuse_pass, REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
paddle::framework::ir::QuantDequantFusePass); paddle::framework::ir::QuantDequantFusePass);
REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1) .LE("conv2d", 1)
...@@ -338,5 +337,5 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) ...@@ -338,5 +337,5 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.EQ("fake_quantize_abs_max", 0) .EQ("fake_quantize_abs_max", 0)
.EQ("fake_quantize_range_abs_max", 0) .EQ("fake_quantize_range_abs_max", 0)
.EQ("fake_quantize_moving_average_abs_max", 0) .EQ("fake_quantize_moving_average_abs_max", 0)
.EQ("fake_channel_wise_quantize_abs_max", 0) .LE("fake_channel_wise_quantize_abs_max", 1)
.EQ("fake_dequantize_max_abs", 0)); .EQ("fake_dequantize_max_abs", 0));
...@@ -240,7 +240,13 @@ class OpVersionComparator { ...@@ -240,7 +240,13 @@ class OpVersionComparator {
if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \ if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \
version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \ version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \
} \ } \
return version_id cmp_math target_version_; \ bool check_ok = version_id cmp_math target_version_; \
if (!check_ok) { \
LOG(WARNING) << "Check op version in pass failed. op name:" \
<< op_name_.c_str() << " op_version:" << version_id \
<< " target_version:" << target_version_; \
} \
return check_ok; \
} \ } \
virtual ~OpVersion##cmp_name##Comparator() {} \ virtual ~OpVersion##cmp_name##Comparator() {} \
\ \
...@@ -326,6 +332,11 @@ class PassVersionCheckerRegistrar { ...@@ -326,6 +332,11 @@ class PassVersionCheckerRegistrar {
return instance; return instance;
} }
PassVersionCheckers& Register(const std::string& pass_name) { PassVersionCheckers& Register(const std::string& pass_name) {
PADDLE_ENFORCE_EQ(pass_version_checkers_map_.find(pass_name),
pass_version_checkers_map_.end(),
platform::errors::AlreadyExists(
"PassVersionCheckers(%s) has alredy been registered.",
pass_name.c_str()));
return pass_version_checkers_map_[pass_name]; return pass_version_checkers_map_[pass_name];
} }
bool IsPassCompatible(const std::string& fuse_pass_name) const { bool IsPassCompatible(const std::string& fuse_pass_name) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册