From f3ba7253ae0b45fed4c25a226cf5d696e89bc13d Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Thu, 3 Dec 2020 19:00:32 +0800 Subject: [PATCH] Trt del pass (#29316) * cherry-pick a22ea652cf214d8e5a4d41fe48e615f14c5ecb49 * fix analysis_config bug. (#29304) * fix code format Co-authored-by: Wilber --- paddle/fluid/inference/api/analysis_config.cc | 24 ++++++++++++++----- paddle/fluid/inference/api/api_tester.cc | 13 ++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 9d86eef7f15..3691df73ef6 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -170,12 +170,24 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { #undef CP_MEMBER - // Update(); - // Update() will reset all the passes, when some tensorRT pass is deleted in - // other.pass_builder(), it will set again, so just copy the passes. - pass_builder_->ClearPasses(); - for (const std::string &pass : other.pass_builder()->AllPasses()) { - pass_builder_->AppendPass(pass); + Update(); + if (use_tensorrt_) { + // Update() will reset all the passes, when some tensorRT pass is deleted in + // other.pass_builder(), it will set again, so we just remove the + // deleted_pass. + auto all_passes = kTRTSubgraphPasses; + auto other_passes = other.pass_builder()->AllPasses(); + // We should sort them, because the user may call the SwitchIrDebug + // interface, which will change the pass. + std::sort(all_passes.begin(), all_passes.end()); + std::sort(other_passes.begin(), other_passes.end()); + std::vector deleted_passes; + std::set_difference(all_passes.begin(), all_passes.end(), + other_passes.begin(), other_passes.end(), + std::inserter(deleted_passes, deleted_passes.begin())); + for (auto ps : deleted_passes) { + pass_builder_->DeletePass(ps); + } } } diff --git a/paddle/fluid/inference/api/api_tester.cc b/paddle/fluid/inference/api/api_tester.cc index 2c450ef7cea..80ec84c8d95 100644 --- a/paddle/fluid/inference/api/api_tester.cc +++ b/paddle/fluid/inference/api/api_tester.cc @@ -67,4 +67,17 @@ TEST(paddle_inference_api, get_version) { ASSERT_FALSE(version.empty()); } +TEST(paddle_inference_api, AnalysisConfigCopyCtor) { + AnalysisConfig cfg1; + cfg1.EnableUseGpu(10); + cfg1.EnableTensorRtEngine(); + std::string delete_pass("skip_layernorm_fuse_pass"); + cfg1.pass_builder()->DeletePass(delete_pass); + AnalysisConfig cfg2(cfg1); + + auto passes = cfg2.pass_builder()->AllPasses(); + for (auto ps : passes) { + CHECK_NE(ps, delete_pass); + } +} } // namespace paddle -- GitLab